• Rebuilding GPT2 inference in ~500 lines of code

  • ¶

    Hey! Like most engineers I’ve been looking at the latest LLMs and wondering how they actually work.

    I watched a couple of videos to try to understand them (Karpathy’s video series is pretty good), but they’re mostly taught by machine learning PhDs, so they tend to gloss over the non-mathy parts of the problem.

    To break it down further, I rebuilt GPT-2 inference from scratch. This file loads the open-source GPT-2 weights and applies all the inference math until we get to the results. I’m also going to add explanations for everything that happens along the way.

    You can find the full code on GitHub.

    By the way, if you have any questions or comments, hit me up at hello at khamidou.com!

    A word of warning

    It’s written in Swift because:

    1. I’m using an old MacBook Air
    2. Apple recently added hardware-accelerated routines to make inference and training faster, and I wanted to try them.

    It’s very limited Swift though, and you don’t really need to know the language to follow this.

    Running a model

    When an ML person talks about running a model, what they actually mean is applying the mathematical operations defined in the model weights to input data.

    In a way, the model is like a program that your inference engine is going to interpret. The weights are like the code (or bytecode, if we’re going by that interpreter metaphor).

  • ¶

    A model is made of multiple layers stacked one after the other. In the case of GPT models, we take the input text, convert it to something computers can work with (tensors, which are basically multi-dimensional arrays of numbers), and then pass that data through several transformer blocks.

  • ¶

    Below you can see a diagram of a basic transformer-based model. It’s slightly different from GPT-2 but close enough to get the gist.

    [Figure: GPT-2 architecture]

  • ¶

    Transformers have been super popular recently because they let a machine learning model take context into account. There are many somewhat nebulous explanations of transformers; I like this one from Bertrand Serlet.

  • ¶

    We’ll get to the thorny parts of the transformer later, but for now let’s treat it as a black box that can learn context from sequences and decide which parts of a stream are important.

  • ¶

    Loading a model

    The GPT-2 weights are available on Hugging Face. They’re distributed in formats built for the Python ML ecosystem (PyTorch checkpoints and Hugging Face’s SafeTensors format), which our Swift program can’t read out of the box.

    To be able to load the weights, I cheated a bit and wrote a Python script that loads the Hugging Face version of GPT-2, then exports the weights to a binary file.

    It outputs two files:

    • a binary file with the raw weights, concatenated together
    • a manifest file that describes the type of each weight and which layer it belongs to.

    This is what the manifest looks like:

    {
      "model_id": "gpt2",
      "dtype": "float32",
      "config": {
        "vocab_size": 50257,
        "n_layer": 12,
        "n_head": 12,
        "n_embd": 768,
        "n_positions": 1024,
        "bos_token_id": 50256,
        "eos_token_id": 50256
      },
      "tensors": [
        {
          "name": "transformer.wte.weight",
          "layer_type": "token_embedding",
          "shape": [
            50257,
            768
          ],
          "dtype": "float32",
          "byte_offset": 0,
          "nbytes": 154389504,
          "sha256": "e182e433b37dbdb47448e0413c840edf6965113c1fc8048c63b69795d8cf875a"
        },
        {
          "name": "transformer.wpe.weight",
          "layer_type": "positional_embedding",
          "shape": [
            1024,
            768
          ],
          "dtype": "float32",
          "byte_offset": 154389504,
          "nbytes": 3145728,
          "sha256": "29c69587e2af826b7c159c38620a32e549a925e6d9a0dc37cb563f377f5be772"
        },
        {
          "name": "transformer.h.0.ln_1.weight",
          "layer_type": "pre_attn_layernorm",
          "shape": [
            768
          ],
          "dtype": "float32",
          "byte_offset": 157535232,
          "nbytes": 3072,
          "sha256": "87c92a5f0409ab2a8f2b92a9ebbb62ab9215590402d38929402c84357d53e4ae"
        },
    
    etc…

    So basically, the manifest describes the type of each tensor and where it starts and ends in the binary file. That means we can iterate through the manifest and transform data until we get to the result, just like an interpreter would!
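    Because every entry records its shape and byte range, you can sanity-check the manifest before running anything. Below is a minimal sketch, assuming float32 tensors packed back to back in manifest order (which is what the offsets in the excerpt above suggest). TensorEntry and checkLayout are hypothetical names for this illustration, not part of the original code.

```swift
import Foundation

// Hypothetical, minimal mirror of one manifest entry (only the fields we check).
struct TensorEntry: Decodable {
    let name: String
    let shape: [Int]
    let byteOffset: Int
    let nbytes: Int

    enum CodingKeys: String, CodingKey {
        case name, shape, nbytes
        case byteOffset = "byte_offset"
    }
}

/// Returns a list of problems; an empty list means every tensor's size matches
/// its shape (4 bytes per float32) and tensors are packed with no gaps.
func checkLayout(_ tensors: [TensorEntry]) -> [String] {
    var problems: [String] = []
    var expectedOffset = 0
    for t in tensors {
        let expectedBytes = t.shape.reduce(1, *) * MemoryLayout<Float32>.size
        if t.nbytes != expectedBytes {
            problems.append("\(t.name): nbytes \(t.nbytes) != \(expectedBytes)")
        }
        if t.byteOffset != expectedOffset {
            problems.append("\(t.name): offset \(t.byteOffset) != \(expectedOffset)")
        }
        expectedOffset = t.byteOffset + t.nbytes
    }
    return problems
}

// The first three entries of the manifest excerpt above (sha256 and dtype omitted).
let json = """
[
  {"name": "transformer.wte.weight", "shape": [50257, 768], "byte_offset": 0, "nbytes": 154389504},
  {"name": "transformer.wpe.weight", "shape": [1024, 768], "byte_offset": 154389504, "nbytes": 3145728},
  {"name": "transformer.h.0.ln_1.weight", "shape": [768], "byte_offset": 157535232, "nbytes": 3072}
]
"""
let entries = try JSONDecoder().decode([TensorEntry].self, from: Data(json.utf8))
let problems = checkLayout(entries)
print(problems.isEmpty ? "layout OK" : problems.joined(separator: "\n"))  // layout OK
```

    For example, 50257 × 768 floats × 4 bytes is exactly the 154389504 bytes recorded for transformer.wte.weight, and the next tensor starts right where it ends.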

    Let’s jump into the actual code now.

  • ¶

    The actual code

    import Foundation
    import Metal
    import MetalPerformanceShadersGraph
  • ¶

    To make sure the code runs as fast as possible, we’ll be using Apple’s Metal Performance Shaders Graph (MPSGraph) framework. It takes a second to wrap your head around it, but basically you define an execution graph that can contain multiple operations. Once you’re done, you hand the graph to MPSGraph, which compiles it into optimized Metal code that runs on the Mac’s GPU.

    let device = MTLCreateSystemDefaultDevice()!
    let graph = MPSGraph()
    let graphDevice = MPSGraphDevice(mtlDevice: device)
    
    let arguments = CommandLine.arguments
    guard arguments.count == 4 else {
      print("Usage: \(arguments[0]) <manifest.json> <weights.bin> 'prompt'")
      print(
        "Example: \(arguments[0]) tinygpt2.manifest.json tinygpt2.bin \"hello what is your name? \"")
      exit(1)
    }
    
    let manifestPath = arguments[1]
    let weightsPath = arguments[2]
    let textToEncode = arguments[3]
  • ¶

    We’re going to parse the manifest file first. To do this, we define a struct, much like you would in Go.

    struct Manifest: Codable {
      let tensors: [TensorInfo]
      let config: Config
    
      struct Config: Codable {
        let vocabSize: Int
        let nLayer: Int
        let nHeads: Int
        let nEmbd: Int
        let nPositions: Int
  • ¶

    Swift has this slightly odd convention where you define a CodingKeys enum to map JSON keys (like vocab_size) to property names (like vocabSize).

        enum CodingKeys: String, CodingKey {
          case vocabSize = "vocab_size"
          case nLayer = "n_layer"
          case nHeads = "n_head"
          case nEmbd = "n_embd"
          case nPositions = "n_positions"
        }
      }
    
      enum CodingKeys: String, CodingKey {
        case tensors = "tensors"
        case config = "config"
      }
    
      struct TensorInfo: Codable {
        let name: String
        let layerType: String
        let shape: [Int]
        let dataType: String
        let byteOffset: Int
        let nBytes: Int
    
        enum CodingKeys: String, CodingKey {
          case name = "name"
          case layerType = "layer_type"
          case shape = "shape"
          case dataType = "dtype"
          case byteOffset = "byte_offset"
          case nBytes = "nbytes"
        }
      }
    }
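    To see what CodingKeys buys us, here is a small standalone sketch: a trimmed-down, hypothetical copy of Manifest.Config decoding the config section of the manifest. Without the enum, JSONDecoder would look for JSON keys spelled exactly like the property names (vocabSize, nHeads, ...), which the manifest doesn’t have.

```swift
import Foundation

// Trimmed-down copy of Manifest.Config, for illustration only.
struct Config: Codable {
    let vocabSize: Int
    let nHeads: Int
    let nEmbd: Int

    // Each case name is the Swift property; each raw value is the JSON key.
    enum CodingKeys: String, CodingKey {
        case vocabSize = "vocab_size"
        case nHeads = "n_head"
        case nEmbd = "n_embd"
    }
}

let configJSON = """
{"vocab_size": 50257, "n_layer": 12, "n_head": 12, "n_embd": 768}
"""
// JSON keys with no matching case (like "n_layer" here) are simply ignored.
let config = try JSONDecoder().decode(Config.self, from: Data(configJSON.utf8))
print(config.vocabSize, config.nHeads, config.nEmbd / config.nHeads)  // 50257 12 64
```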
  • ¶

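    Another Swift surprise: the guard statement checks that a condition is true and forces you to exit the current scope (here with fatalError) if it isn’t. We use it to make sure both the manifest and the weights file can be loaded. The .alwaysMapped option asks Foundation to memory-map the weights file instead of copying all of it into memory up front.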

    guard let manifestData = try? Data(contentsOf: URL(fileURLWithPath: manifestPath)) else {
      fatalError("Failed to load manifest at \(manifestPath)")
    }
    
    guard let parsedManifest = try? JSONDecoder().decode(Manifest.self, from: manifestData) else {
      fatalError("Failed to parse manifest JSON")
    }
    
    print("Manifest config: \(parsedManifest.config)")
    
    guard
      let weightsData = try? Data(contentsOf: URL(fileURLWithPath: weightsPath), options: .alwaysMapped)
    else {
      fatalError("Failed to load weights from \(weightsPath)")
    }
  • ¶

    These two files (encoder.json and vocab.bpe) are used by the GPT-2 tokenizer, the code that turns text into a list of numbers (tokens). I want to focus on the transformer implementation, so we’ll gloss over this, but this video is a good breakdown of how it works.

    For now, just assume that we have a tokenizer object with encode() and decode() methods that convert text to tokens and back.

    let encoderURL = URL(fileURLWithPath: "encoder.json")
    let bpeURL = URL(fileURLWithPath: "vocab.bpe")
    let tokenizer = try GPT2Tokenizer.load(encoderJSON: encoderURL, mergesTXT: bpeURL)
  • ¶

    Convert our input text to tokens. This will be a list like [15496, 616, 1438, 318].

    var tokens = tokenizer.encode(textToEncode)
  • ¶

    Now let’s run the model in a loop, appending each new token to the input and feeding the whole sequence back in at every turn.

    while tokens.count < 30 {
      print("Current tokens: \(tokens.count), generating more...")
      let nextToken = runLLMInALoop(tokens: tokens)
      tokens.append(nextToken)
      let decodedText = tokenizer.decode(tokens)
      print("Generated text: \(decodedText)")
    }
    
    func runLLMInALoop(tokens: [Int32]) -> Int32 {
  • ¶

    Without getting into too many details about Metal, there’s a strict separation between GPU and CPU memory, even though they end up sharing the same RAM behind the scenes. This means we have to create GPU-specific buffers if we want to send data to the GPU.

      let tokensBuf = device.makeBuffer(
        bytes: tokens,
        length: tokens.count * MemoryLayout<Int32>.stride,
        options: [.storageModeShared])!
  • ¶

    MPSGraphTensorData is MPSGraph’s basic data container for tensors. It’s basically a buffer of numbers with a data type (int32, float32, etc.) and a shape.

      let tokensTensorData = MPSGraphTensorData(
        tokensBuf,
        shape: [NSNumber(value: 1), NSNumber(value: tokens.count)],
        dataType: .int32)
  • ¶

    This tensor holds our intermediate result: each step reads it and overwrites it with its own output.

      var workingTensorData: MPSGraphTensorData = MPSGraphTensorData(
        device: graphDevice,
        data: Data(),
        shape: [NSNumber(value: tokens.count), NSNumber(value: parsedManifest.config.nEmbd)],
        dataType: .float32)
  • ¶

    A utility function to create a tensor from the manifest entry at a given index. Don’t come at me for defining a function within a function; it’s just to make things clearer for the reader.

      func loadWeights(from manifest: Manifest, at index: Int, using weightsData: Data) -> MPSGraphTensorData {
        guard manifest.tensors.indices.contains(index) else {
          fatalError("Index out of bounds for tensors array")
        }
    
        let tensorInfo = manifest.tensors[index]
        let data = weightsData.subdata(
          in: tensorInfo.byteOffset..<(tensorInfo.byteOffset + tensorInfo.nBytes))
    
        return MPSGraphTensorData(
          device: graphDevice,
          data: data,
          shape: tensorInfo.shape.map { NSNumber(value: $0) },
  • ¶

    Note that we always use Float32 in this file. Float16 would be faster, but some operations don’t support it, so for simplicity we stick to Float32.

          dataType: .float32
        )
      }
  • ¶

    A transformer block includes residual connections: we add a sub-layer’s unaltered input back into its output. To do that, we have to save the input in this tensor.
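    Here is how the residuals flow through one GPT-2 block, as a hypothetical CPU sketch with the sub-layers passed in as closures so it stays self-contained. GPT-2 normalizes before each sub-layer ("pre-LayerNorm"), and the attention sub-layer’s output, residual included, becomes the residual for the MLP. The names transformerBlock and add are illustrative only.

```swift
typealias Vec = [Float]

// Element-wise sum of two vectors of the same length.
func add(_ a: Vec, _ b: Vec) -> Vec { zip(a, b).map { $0 + $1 } }

func transformerBlock(_ x: Vec,
                      ln1: (Vec) -> Vec, attention: (Vec) -> Vec,
                      ln2: (Vec) -> Vec, mlp: (Vec) -> Vec) -> Vec {
    // 1. The block input is the residual for the attention sub-layer.
    let afterAttention = add(x, attention(ln1(x)))
    // 2. The attention sub-layer's output (residual included) is the residual for the MLP.
    return add(afterAttention, mlp(ln2(afterAttention)))
}

// Toy sub-layers: identity norms, "attention" adds 1, "MLP" doubles its input.
let y = transformerBlock([1, 2],
                         ln1: { v in v },
                         attention: { v in v.map { $0 + 1 } },
                         ln2: { v in v },
                         mlp: { v in v.map { $0 * 2 } })
print(y)  // [9.0, 15.0]
```

    The Metal code in this file implements the same flow across loop iterations, by storing residualTensorData at the pre-attention LayerNorm and updating it after attention.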

      var residualTensorData: MPSGraphTensorData? = nil
    
      for (i, layer) in parsedManifest.tensors.enumerated() {
  • ¶

    Encoding the data

    The very first step of GPT-2 is to map each token to a vector, giving us a stack of vectors, one per token. To do this, we use the model’s token-embedding matrix, which maps each token ID (an integer from 0 to 50,256) to a vector of length 768 for GPT-2.

        if layer.name == "transformer.wte.weight" {
  • ¶

    Create a tensor data object for the embedding weights, then load it.

          let wteTensorData = loadWeights(
            from: parsedManifest,
            at: i,
            using: weightsData)
  • ¶

    Note: graph.placeholder is how MPSGraph declares an input whose data we provide when we run the graph. We need placeholders both for the embedding matrix and for our token IDs.

          let E = graph.placeholder(
            shape: [
              NSNumber(value: parsedManifest.config.vocabSize),
              NSNumber(value: parsedManifest.config.nEmbd),
            ],
            dataType: .float32,
            name: "EmbeddingTable")
    
          let ids = graph.placeholder(
            shape: [NSNumber(value: 1), NSNumber(value: tokens.count)],
            dataType: .int32,
            name: "TokenIDs")
  • ¶

    Gather is simply a batched lookup operation: for each token ID n in the list, MPS returns row n of the embedding matrix. Since the IDs have shape [1, n_tokens], the result has shape [1, n_tokens, 768].

          let embeds = graph.gather(
            withUpdatesTensor: E,
            indicesTensor: ids,
            axis: 0,
            batchDimensions: 0,
            name: "embeds")
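    The lookup that gather performs here can be sketched on the CPU as plain row slicing. This is a hypothetical illustration (embed is not a function from this file), assuming the table is stored row-major as a flat [vocabSize × dim] array, which is how the raw weights sit in the binary file.

```swift
// Hypothetical CPU version of the embedding gather, for illustration only.
// `table` is a row-major [vocabSize, dim] matrix stored as a flat array.
func embed(tokens: [Int32], table: [Float], dim: Int) -> [[Float]] {
    return tokens.map { token -> [Float] in
        let start = Int(token) * dim
        return Array(table[start..<(start + dim)])  // row `token` of the table
    }
}

// Toy table: a vocabulary of 3 tokens, 2 dimensions each.
let table: [Float] = [0.0, 0.1,   // token 0
                      1.0, 1.1,   // token 1
                      2.0, 2.1]   // token 2
print(embed(tokens: [2, 0, 2], table: table, dim: 2))
// [[2.0, 2.1], [0.0, 0.1], [2.0, 2.1]]
```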
  • ¶

    Now we actually run the operation graph we defined. We have to pass in an MPSGraphTensorData object for each placeholder and specify the output tensors we care about.

    Behind the scenes, Apple’s framework will run these operations on the GPU.

          let outTD = graph.run(
            feeds: [E: wteTensorData, ids: tokensTensorData],
            targetTensors: [embeds],
            targetOperations: nil
          )
  • ¶

    Then we fetch the result and save it into workingTensorData to pass into the next step.

          workingTensorData = outTD[embeds]!
        } else if layer.name == "transformer.wpe.weight" {
  • ¶

    The second step is to add positional information to the tokens, because GPT-2 also uses word order as a signal. GPT-2 learns a table of position embeddings (one 768-dimensional row for each of the 1,024 possible positions). We look up the row for each token’s position and simply add it to that token’s embedding vector.

          let P = graph.placeholder(
            shape: [
              NSNumber(value: parsedManifest.config.nPositions),
              NSNumber(value: parsedManifest.config.nEmbd),
            ],
            dataType: .float32,
            name: "WPE"
          )
    
          let wpeData = weightsData.subdata(in: layer.byteOffset..<(layer.byteOffset + layer.nBytes))
          let Pdata = MPSGraphTensorData(
            device: graphDevice,
            data: wpeData,
            shape: [
              NSNumber(value: parsedManifest.config.nPositions),
              NSNumber(value: parsedManifest.config.nEmbd),
            ],
            dataType: .float32
          )
    
          let seq = tokens.count
          let posI32: [Int32] = (0..<seq).map(Int32.init)
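    On the CPU, the gather-then-add for positions boils down to the sketch below. It is a hypothetical illustration (addPositions is not part of this file), assuming a row-major flat position table: token p gets row p added to it, so the same token ends up with a different vector at each position.

```swift
// Hypothetical CPU sketch: add the positional row for position p to token p's embedding.
// `positionTable` is a row-major [nPositions, dim] matrix stored as a flat array.
func addPositions(_ embeddings: [[Float]], positionTable: [Float], dim: Int) -> [[Float]] {
    var result: [[Float]] = []
    for (position, vector) in embeddings.enumerated() {
        let start = position * dim
        var summed = vector
        for j in 0..<dim {
            summed[j] += positionTable[start + j]
        }
        result.append(summed)
    }
    return result
}

let tokenVectors: [[Float]] = [[1, 1], [1, 1], [1, 1]]        // the same token three times
let positions: [Float] = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]  // 4 positions, dim 2
print(addPositions(tokenVectors, positionTable: positions, dim: 2))
// [[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]]
```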
  • ¶

    We have to do this because graph.constant() expects raw bytes as Data, not an array of Int32.

          let posIdsData = posI32.withUnsafeBufferPointer { Data(buffer: $0) }
          let posIds = graph.constant(
            posIdsData,
            shape: [NSNumber(value: seq)],
            dataType: .int32)
    
          let posEmbeds = graph.gather(
            withUpdatesTensor: P,
            indicesTensor: posIds,
            axis: 0,
            batchDimensions: 0,
            name: "posEmbeds")
    
          let embedPlaceholder = graph.placeholder(
            shape: workingTensorData.shape,
            dataType: .float32,
            name: "embedPos")
    
          let addedTensor = graph.addition(embedPlaceholder, posEmbeds, name: "embedsWithPos")
  • ¶

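    The token-embedding gather left a leading batch dimension (shape [1, n_tokens, 768]), and the addition broadcasts the [n_tokens, 768] position embeddings to that shape. We squeeze (remove) that leading dimension so the tensor has shape [n_tokens, 768].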

          let squeezeLast = graph.squeeze(
            addedTensor,
            axes: [NSNumber(value: 0)],
            name: "squeeze_last")
    
          let outTD = graph.run(
            feeds: [embedPlaceholder: workingTensorData, P: Pdata],
            targetTensors: [squeezeLast],
            targetOperations: nil)
    
          workingTensorData = outTD[squeezeLast]!
        } else if layer.layerType.hasSuffix("_layernorm") && layer.name.hasSuffix(".weight") {
  • ¶

    Layer normalization

    The next step is to implement layer normalization. It’s used a lot in GPT-2: every transformer block normalizes its input before the attention sub-layer and again before the Multi-Layer Perceptron (pre-attention and pre-MLP normalization, respectively), and there’s one final LayerNorm after the last block.

    This blog post is a good introduction to why LayerNorm is useful and how it’s implemented.
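    For reference, here is LayerNorm on a single vector as a minimal CPU sketch: subtract the mean, divide by the standard deviation, then apply the learned scale (gamma) and shift (beta). The function name is hypothetical, and eps = 1e-5 is an assumption taken from GPT-2’s layer_norm_epsilon in the Hugging Face config; the layerNorm helper used by this file isn’t shown in this section.

```swift
import Foundation

// Minimal CPU LayerNorm over one vector, for illustration only.
func layerNormCPU(_ x: [Float], gamma: [Float], beta: [Float], eps: Float = 1e-5) -> [Float] {
    let n = Float(x.count)
    let mean = x.reduce(0, +) / n
    // Population variance (divide by n), as LayerNorm uses.
    let variance = x.map { ($0 - mean) * ($0 - mean) }.reduce(0, +) / n
    let invStd = 1 / sqrt(variance + eps)
    var out = [Float](repeating: 0, count: x.count)
    for i in 0..<x.count {
        out[i] = (x[i] - mean) * invStd * gamma[i] + beta[i]
    }
    return out
}

let normalized = layerNormCPU([1, 2, 3, 4], gamma: [1, 1, 1, 1], beta: [0, 0, 0, 0])
print(normalized)  // approximately [-1.342, -0.447, 0.447, 1.342]
```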

  • ¶

    We are starting a transformer block. Transformer blocks have residual connections, so we need to save the working tensor data to add it to the result later.

          if layer.layerType == "pre_attn_layernorm" {
            residualTensorData = workingTensorData
          }
  • ¶

    LayerNorm has learned scaling and shifting parameters, so we need to load them from the open-source weights.

          let lnScaling = graph.placeholder(
            shape: [NSNumber(value: parsedManifest.config.nEmbd)],
            dataType: .float32,
            name: "lnScaling")
          let lnShifting = graph.placeholder(
            shape: [NSNumber(value: parsedManifest.config.nEmbd)],
            dataType: .float32,
            name: "lnShifting")
  • ¶

    Check that the next tensor in the manifest is the shifting factor (the LayerNorm bias).

          guard parsedManifest.tensors.indices.contains(i + 1),
            parsedManifest.tensors[i + 1].name.hasSuffix(".bias")
          else {
            fatalError("Expected next tensor to be bias for layer norm")
          }
  • ¶

    ...then load the weights.

          let lnScalingTD = loadWeights(
            from: parsedManifest,
            at: i,
            using: weightsData)
    
          let lnShiftingTD = loadWeights(
            from: parsedManifest,
            at: i + 1,
            using: weightsData)
    
          let tempPlaceholder = graph.placeholder(
            shape: workingTensorData.shape,
            dataType: .float32,
            name: "TempPlaceholder")
  • ¶

    Finally, we run LayerNorm. If you scroll down to the bottom of this file, you’ll find a basic implementation of it.

          let lnResult = layerNorm(graph: graph, x: tempPlaceholder, gamma: lnScaling, beta: lnShifting)
    
          let out = graph.run(
            feeds: [
              lnScaling: lnScalingTD, lnShifting: lnShiftingTD, tempPlaceholder: workingTensorData,
            ],
            targetTensors: [lnResult],
            targetOperations: nil)
    
          workingTensorData = out[lnResult]!
        } else if layer.layerType.hasSuffix("_layernorm") && layer.name.hasSuffix(".bias") {
          continue  // Skip the bias tensor, we already processed it
        } else if layer.layerType == "attn_qkv_proj" && layer.name.hasSuffix(".weight") {
  • ¶

    Alright, now let’s look into the most complicated part: attention. I really recommend reading this post, which breaks down the intuition behind attention and transformers. However, let me try to explain it with an analogy that no serious ML researcher would make: attention is kind of like a search engine. You have three vectors:

    1. the query vector, which encodes what we’re looking for.
    2. the key vector, which encodes how relevant each piece of data is to a given query. Relevance is computed by taking the dot product of the query and key vectors.
    3. the value vector, which represents the actual data associated with each key.

    Each of these vectors is computed by multiplying the input data with a learned weight matrix (which we load from the weights file).
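    To make the search-engine analogy concrete, here is a minimal single-head causal attention sketch on the CPU. It is an illustration, not this file’s implementation: the Metal code below hands the same computation to MPSGraph for all 12 heads at once, and it hides future tokens by adding a large negative mask rather than skipping them as this sketch does. causalAttention is a hypothetical name.

```swift
import Foundation

// Minimal single-head causal attention, for illustration only.
// q, k, v are [T][d] arrays (one row per token).
func causalAttention(q: [[Float]], k: [[Float]], v: [[Float]]) -> [[Float]] {
    let T = q.count
    let d = q[0].count
    let scale = 1 / sqrt(Float(d))
    var output: [[Float]] = []
    for i in 0..<T {
        // Relevance of each token j <= i to token i: scaled dot product of query and key.
        // Tokens j > i are never scored, which is what the causal mask achieves.
        var scores: [Float] = []
        for j in 0...i {
            var dot: Float = 0
            for c in 0..<d { dot += q[i][c] * k[j][c] }
            scores.append(dot * scale)
        }
        // Softmax over the allowed positions (subtracting the max for numerical stability).
        let maxScore = scores.max()!
        let weights = scores.map { exp($0 - maxScore) }
        let total = weights.reduce(0, +)
        // Output for token i: weighted sum of the value vectors.
        var row = [Float](repeating: 0, count: d)
        for j in 0...i {
            for c in 0..<d { row[c] += weights[j] / total * v[j][c] }
        }
        output.append(row)
    }
    return output
}

let out = causalAttention(q: [[1, 0], [1, 0]],
                          k: [[1, 0], [0, 1]],
                          v: [[10, 0], [0, 10]])
print(out[0])  // [10.0, 0.0]: the first token can only attend to itself
print(out[1])  // mostly token 0's value, since query 1 matches key 0 best
```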

  • ¶

    In our case, the three projections are packed into a single weight matrix, so one matrix multiplication computes the query, key, and value vectors together.

          let wTensorData = loadWeights(
            from: parsedManifest,
            at: i,
            using: weightsData)
    
          let weightsTensorPlaceholder = graph.placeholder(
            shape: wTensorData.shape,
            dataType: .float32,
            name: "attnWeights")
    
          let inputPlaceholder = graph.placeholder(
            shape: workingTensorData.shape,
            dataType: .float32,
            name: "attnInput")
    
          let projIn = graph.matrixMultiplication(
            primary: inputPlaceholder,
            secondary: weightsTensorPlaceholder,
            name: "attnProj")
  • ¶

    Add the bias. GPT-2’s QKV projection always has one, so a missing bias is treated as an error.

          guard
            parsedManifest.tensors.indices.contains(i + 1),
            parsedManifest.tensors[i + 1].layerType == "attn_qkv_proj",
            parsedManifest.tensors[i + 1].name.hasSuffix(".bias")
          else {
            fatalError("Expected next tensor to be bias for attention projection")
          }
    
          let biasTensorData = loadWeights(from: parsedManifest, at: i + 1, using: weightsData)
    
          let biasTensorPlaceholder = graph.placeholder(
            shape: biasTensorData.shape,
            dataType: .float32,
            name: "attnBias")
    
          let projInWithBias = graph.addition(projIn, biasTensorPlaceholder, name: "attnProjWithBias")
  • ¶

    Apple added a fused Scaled Dot-Product Attention operation in iOS 18 / macOS 15 (see docs). It runs attention as a single operation, which should be faster than building it by hand from matrix multiplications and a softmax. It’s a bit tricky to use because it expects the input tensors in a specific shape, so we have to do some reshaping and transposing.

          let B = 1  // batch size, we assume 1 for simplicity
          let T = tokens.count  // sequence length
          let H = parsedManifest.config.nHeads  // number of attention heads
          let D = parsedManifest.config.nEmbd   // D is the model's embedding dimension
          let d = D / H  // dimension per head
  • ¶

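    We have to define a mask to hide future tokens. It’s added to the attention scores: 0 where a token may look, and a large negative number (-1e9) above the diagonal, so that after the softmax each token gives effectively zero weight to the tokens that come after it. Without it, the result of attention would be wrong because each token would be able to “see” future tokens.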

          let neg: Float = -1e9
          var maskDataArray = [Float](repeating: 0, count: T * T)
          for i in 0..<T {
            for j in (i + 1)..<T { maskDataArray[i * T + j] = neg }
          }
    
          let maskPlaceholder = graph.placeholder(
            shape: [1, 1, T, T].map(NSNumber.init),
            dataType: .float32,
            name: "attnMask"
          )
    
          let maskTensorData = MPSGraphTensorData(
            device: graphDevice,
            data: Data(bytes: &maskDataArray, count: maskDataArray.count * MemoryLayout<Float>.size),
            shape: [NSNumber(value: 1), NSNumber(value: 1), NSNumber(value: T), NSNumber(value: T)],
            dataType: .float32
          )
    
          let q = graph.sliceTensor(
            projInWithBias, dimension: -1, start: 0, length: D, name: "attention_q")
          let k = graph.sliceTensor(
            projInWithBias, dimension: -1, start: D, length: D, name: "attention_k")
          let v = graph.sliceTensor(
            projInWithBias, dimension: -1, start: 2 * D, length: D, name: "attention_v")
  • ¶

    Our input tensors are of shape [B,T,D] but the attention operation expects them to be [B,H,T,d] (batch size, number of heads, sequence length, dimension per head).

          func toBHTD(_ t: MPSGraphTensor) -> MPSGraphTensor {
            let bthd = graph.reshape(t, shape: [B, T, H, d].map(NSNumber.init), name: nil)
            return graph.transposeTensor(bthd, dimension: 1, withDimension: 2, name: nil)
          }
    
          let q_bhtd = toBHTD(q)
          let k_bhtd = toBHTD(k)
          let v_bhtd = toBHTD(v)
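    What the reshape-plus-transpose does to the data can be spelled out with index math: head h owns columns h·d ..< (h+1)·d of every token’s vector. A hypothetical CPU sketch (batch dimension dropped, splitHeads is an illustrative name):

```swift
// Hypothetical CPU sketch of reshape [T, D] -> [T, H, d] followed by transpose -> [H, T, d].
// Input is [T][D]; head h gets columns h*d ..< (h+1)*d of every row.
func splitHeads(_ x: [[Float]], heads H: Int) -> [[[Float]]] {
    let T = x.count
    let d = x[0].count / H
    var result = [[[Float]]](repeating: [[Float]](repeating: [], count: T), count: H)
    for h in 0..<H {
        for t in 0..<T {
            result[h][t] = Array(x[t][(h * d)..<((h + 1) * d)])
        }
    }
    return result
}

// Two tokens, D = 4, split into two heads of size d = 2.
let perHead = splitHeads([[0, 1, 2, 3],
                          [4, 5, 6, 7]], heads: 2)
print(perHead)  // [[[0.0, 1.0], [4.0, 5.0]], [[2.0, 3.0], [6.0, 7.0]]]
```

    In GPT-2 small, D = 768 and H = 12, so each head works on a 64-dimensional slice.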
  • ¶

    Finally we can run attention!

          let scale = 1.0 / sqrt(Float(d))
          let attn = graph.scaledDotProductAttention(
            query: q_bhtd,
            key: k_bhtd,
            value: v_bhtd,
            mask: maskPlaceholder,
            scale: scale,
            name: "attention_sdpa"
          )
  • ¶

    Extract the output of attention and reshape it back to [B,T,D]

          let attn_bthd = graph.transposeTensor(attn, dimension: 1, withDimension: 2, name: nil)
          let attn_btd = graph.reshape(attn_bthd, shape: [B, T, D].map(NSNumber.init), name: nil)
  • ¶

    Once we have the attention output, we need to compute the output projection: another learned weight matrix that we multiply the vectors with (plus a bias).

          guard
            parsedManifest.tensors.indices.contains(i + 2)
              && parsedManifest.tensors[i + 2].layerType == "attn_out_proj"
              && parsedManifest.tensors[i + 2].name.hasSuffix(".weight")
              && parsedManifest.tensors.indices.contains(i + 3)
              && parsedManifest.tensors[i + 3].name.hasSuffix(".bias")
          else {
            fatalError("Expected tensors i + 2 and i + 3 to be the output projection weight and bias")
          }
  • ¶

    The tensor at i + 2 holds the output projection weights, and the one at i + 3 holds its bias (i + 1 was the QKV bias).

          let outputWeightsData = loadWeights(
            from: parsedManifest,
            at: i + 2,
            using: weightsData)
    
          let outputWeightsPlaceholder = graph.placeholder(
            shape: outputWeightsData.shape,
            dataType: .float32,
            name: "attnOutputWeights")
    
          let outputBiasData = loadWeights(
            from: parsedManifest,
            at: i + 3,
            using: weightsData)
    
          let outputBiasPlaceholder = graph.placeholder(
            shape: outputBiasData.shape,
            dataType: .float32,
            name: "attnOutputBias")
    
          let projection = graph.addition(
            graph.matrixMultiplication(
              primary: attn_btd,
              secondary: outputWeightsPlaceholder,
              name: "attnProjection"), outputBiasPlaceholder, name: "attnProjectionWithBias")
    
          guard residualTensorData != nil else {
            fatalError("Residual tensor data is nil, expected to be set before attention layer")
          }
  • ¶

    We also need to add a residual connection here. The residual is the input to the attention block.

          let residualPlaceholder = graph.placeholder(
            shape: residualTensorData!.shape,
            dataType: .float32,
            name: "attnResidual")
    
          let residualAdded = graph.addition(
            projection, residualPlaceholder, name: "attnProjectionWithResidual")
    
          let out = graph.run(
            feeds: [
              inputPlaceholder: workingTensorData,
              weightsTensorPlaceholder: wTensorData,
              biasTensorPlaceholder: biasTensorData,
              maskPlaceholder: maskTensorData,
              outputWeightsPlaceholder: outputWeightsData,
              outputBiasPlaceholder: outputBiasData,
              residualPlaceholder: residualTensorData!,
            ],
            targetTensors: [residualAdded, attn_bthd],
            targetOperations: nil)
    
          workingTensorData = out[residualAdded]!
  • ¶

    Update the residual for the next sub-layer: the residual for the MLP is the output of the attention sub-layer. See https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/model.py#L105

          residualTensorData = workingTensorData
        } else if layer.layerType.hasSuffix("mlp_fc_in") && layer.name.hasSuffix(".weight") {
  • ¶

    The last step in the transformer block is piping the (normalized) output through a Multi-Layer Perceptron (MLP). The MLP is made of two linear transformations with a non-linear activation function in between. This is probably the most straightforward part of the transformer to implement 😅

    Note that this is broken into two parts: the input projection and the output projection. The output projection is implemented in the next else if block.
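    For a single token vector, the whole MLP fits in a few lines. This is a hypothetical CPU sketch: linear and mlp are illustrative names, and ReLU stands in for GELU only to keep the toy numbers simple. The weights are laid out [inFeatures][outFeatures], matching the [768, 3072] and [3072, 768] shapes of the Hugging Face GPT-2 MLP weights, which is why the code here multiplies input × weights without a transpose.

```swift
// y = x * W + b with W laid out as [inFeatures][outFeatures].
func linear(_ x: [Float], _ w: [[Float]], _ b: [Float]) -> [Float] {
    var out = b  // start from the bias
    for i in 0..<x.count {
        for j in 0..<b.count {
            out[j] += x[i] * w[i][j]
        }
    }
    return out
}

func mlp(_ x: [Float],
         wIn: [[Float]], bIn: [Float],
         wOut: [[Float]], bOut: [Float],
         activation: (Float) -> Float) -> [Float] {
    let hidden = linear(x, wIn, bIn).map(activation)  // expand (4x wider in GPT-2)
    return linear(hidden, wOut, bOut)                 // project back down
}

// Toy sizes: 2 -> 4 -> 2.
let y = mlp([1, -1],
            wIn: [[1, 0, 1, 0], [0, 1, 0, 1]], bIn: [0, 0, 0, 0],
            wOut: [[1, 0], [0, 1], [1, 0], [0, 1]], bOut: [0.5, 0.5],
            activation: { max(0, $0) })
print(y)  // [2.5, 0.5]
```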

          let wTensorData = loadWeights(
            from: parsedManifest,
            at: i,
            using: weightsData)
    
          let weightsTensorPlaceholder = graph.placeholder(
            shape: wTensorData.shape,
            dataType: .float32,
            name: "mlpWeights")
  • ¶

    As always, check that the next tensor is the bias

          guard
            parsedManifest.tensors.indices.contains(i + 1)
              && parsedManifest.tensors[i + 1].layerType == "mlp_fc_in"
              && parsedManifest.tensors[i + 1].name.hasSuffix(".bias")
          else {
            fatalError("Expected next tensor to be bias for MLP input projection")
          }
    
          let biasTensorData = loadWeights(
            from: parsedManifest,
            at: i + 1,
            using: weightsData)
    
          let biasTensorPlaceholder = graph.placeholder(
            shape: biasTensorData.shape,
            dataType: .float32,
            name: "mlpBias")
    
          let inputPlaceholder = graph.placeholder(
            shape: workingTensorData.shape,
            dataType: .float32,
            name: "mlpInput")
    
          let projOut = graph.matrixMultiplication(
            primary: inputPlaceholder,
            secondary: weightsTensorPlaceholder,
            name: "mlpProj")
    
          let projOutWithBias = graph.addition(projOut, biasTensorPlaceholder, name: "mlpProjWithBias")
  • ¶

    The activation function for the MLP is GELU. Apple’s MPSGraph doesn’t have a built-in version of it, so we have to implement it ourselves. For the sake of simplicity, we use the tanh approximation (which is also the variant GPT-2 itself was trained with). More details about the approximation formula can be found at https://en.wikipedia.org/wiki/Gaussian_error_linear_unit#Approximation
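    The geluTanhApprox helper used below builds the formula out of graph operations; as a reference, here is a scalar CPU sketch of the same tanh approximation, GELU(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))). geluTanh is a hypothetical name for this illustration.

```swift
import Foundation

// Scalar GELU using the tanh approximation, for illustration only.
func geluTanh(_ x: Float) -> Float {
    let c = sqrt(2 / Float.pi)
    return 0.5 * x * (1 + tanh(c * (x + 0.044715 * x * x * x)))
}

// Large positive inputs pass through almost unchanged; large negative ones go to ~0.
let samples: [Float] = [-3, -1, 0, 1, 3]
for x in samples {
    print(x, geluTanh(x))
}
```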

          let geluOut = geluTanhApprox(projOutWithBias, graph)
    
          let res = graph.run(
            feeds: [
              inputPlaceholder: workingTensorData,
              weightsTensorPlaceholder: wTensorData,
              biasTensorPlaceholder: biasTensorData,
            ],
            targetTensors: [geluOut],
            targetOperations: nil)
    
          workingTensorData = res[geluOut]!
        } else if layer.layerType.hasSuffix("mlp_fc_out") && layer.name.hasSuffix(".weight") {
  • ¶

    The second part of the MLP is the output projection, another linear transformation: just another matrix multiplication with learned weights, plus a bias.

          let wTensorData = loadWeights(
            from: parsedManifest,
            at: i,
            using: weightsData)
    
          let weightsTensorPlaceholder = graph.placeholder(
            shape: wTensorData.shape,
            dataType: .float32,
            name: "mlpOutputWeights")
  • ¶

    Check that the next tensor is the bias

          guard
            parsedManifest.tensors.indices.contains(i + 1)
              && parsedManifest.tensors[i + 1].layerType == "mlp_fc_out"
              && parsedManifest.tensors[i + 1].name.hasSuffix(".bias")
          else {
            fatalError("Expected next tensor to be bias for MLP output projection")
          }
    
          let biasTensorData = loadWeights(
            from: parsedManifest,
            at: i + 1,
            using: weightsData)
    
          let biasTensorPlaceholder = graph.placeholder(
            shape: biasTensorData.shape,
            dataType: .float32,
            name: "mlpOutputBias")
    
          let inputPlaceholder = graph.placeholder(
            shape: workingTensorData.shape,
            dataType: .float32,
            name: "mlpOutputInput")
    
          let projOut = graph.matrixMultiplication(
            primary: inputPlaceholder,
            secondary: weightsTensorPlaceholder,
            name: "mlpOutputProj")
    
          let projOutWithBias = graph.addition(
            projOut, biasTensorPlaceholder, name: "mlpOutputProjWithBias")
  • ¶

    This is the final output of this transformer block, so we add the residual connection here (the residual is the output of the attention sub-layer).

          guard let residualData = residualTensorData else {
            fatalError("Residual tensor data is nil, expected to be set before MLP output projection")
          }
    
          let residualPlaceholder = graph.placeholder(
            shape: residualData.shape,
            dataType: .float32,
            name: "mlpOutputResidual")
    
          let residualAdded = graph.addition(
            projOutWithBias, residualPlaceholder, name: "mlpOutputProjWithResidual")
    
          let res = graph.run(
            feeds: [
              inputPlaceholder: workingTensorData,
              weightsTensorPlaceholder: wTensorData,
              biasTensorPlaceholder: biasTensorData,
              residualPlaceholder: residualData,
            ],
            targetTensors: [residualAdded],
            targetOperations: nil)
    
          workingTensorData = res[residualAdded]!
        } else {
  • ¶

    Any other layer type is skipped. While debugging, you can log it with print("Skipping layer:", layer.name).

        }
      }
  • ¶

    Final layer

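    OK, we’ve gone through all the layers in the transformer. The final step is to turn the final working tensor into logits, one score per vocabulary token. GPT-2 reuses its token embedding matrix (transformer.wte.weight) for this, so we multiply the hidden states by its transpose.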

      // Final hidden states coming out of the last transformer block
      let finalResults = graph.placeholder(
        shape: workingTensorData.shape,
        dataType: .float32,
        name: "finalHiddenStates")
    
      let wteIndex = parsedManifest.tensors.firstIndex(where: { $0.name == "transformer.wte.weight" })!
      let wordTokenEncoding = graph.placeholder(
        shape: parsedManifest.tensors[wteIndex].shape.map { NSNumber(value: $0) },
        dataType: .float32,
        name: "wordTokenEncoding")
    
      let wordTokenWeightsData = loadWeights(
        from: parsedManifest,
        at: wteIndex,
        using: weightsData)
    
      let transposedWordTokenEncoding = graph.transposeTensor(
        wordTokenEncoding, dimension: 0, withDimension: 1, name: "transposedWordTokenEncoding")
    
      let logits = graph.matrixMultiplication(
        primary: finalResults,
        secondary: transposedWordTokenEncoding,
        name: "logits")
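
    GPT-2 ties its input and output embeddings: the same wte matrix that turns token ids into vectors at the start is reused here to score every vocabulary token. Below is a CPU sketch of that multiplication, assuming row-major arrays with wte stored as [vocab, dModel] (50257 × 768 for GPT-2 small); tiedLogits is a hypothetical helper. Row v of wte is column v of its transpose, so each logit is just a dot product between a hidden state and one token’s embedding.

```swift
// Logits with tied embeddings: logits = h · wteᵀ
// h: [seqLen, dModel] final hidden states, wte: [vocab, dModel] token embeddings.
func tiedLogits(h: [Float], seqLen: Int, wte: [Float], vocab: Int, dModel: Int) -> [Float] {
    precondition(h.count == seqLen * dModel && wte.count == vocab * dModel)
    var logits = [Float](repeating: 0, count: seqLen * vocab)
    for s in 0..<seqLen {
        for v in 0..<vocab {
            // Dot product of position s with the embedding of token v;
            // reading row v of wte is the same as reading column v of wteᵀ.
            var dot: Float = 0
            for d in 0..<dModel {
                dot += h[s * dModel + d] * wte[v * dModel + d]
            }
            logits[s * vocab + v] = dot
        }
    }
    return logits
}
```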
  • ¶

    This gives logits for every position in the prompt, but only the last position predicts the next token, so we slice out that one.

      let lastToken = graph.sliceTensor(
        logits, dimension: 1, start: tokens.count - 1, length: 1, name: "lastToken")
  • ¶

    Next, we apply softmax to convert the logits into probabilities, then pick the top 20 tokens.

      let softmaxedLogits = graph.softMax(
        with: lastToken, axis: -1, name: "softmaxed_logits")
    
      let topk = graph.topK(softmaxedLogits, axis: -1, k: 20, name: "topk")
      let topkIndices = topk[1]
    
      let finalOut = graph.run(
        feeds: [
          finalResults: workingTensorData,
          wordTokenEncoding: wordTokenWeightsData,
        ],
        targetTensors: [topkIndices],
        targetOperations: nil)
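
    The same three steps (keep the last position, softmax, top-k) can be written in plain Swift as a CPU sketch for checking the graph’s output. It assumes row-major [seqLen, vocab] logits; lastRow, softmax and topK are hypothetical helpers. Subtracting the maximum logit before exponentiating is the usual trick to avoid overflow and doesn’t change the result.

```swift
import Foundation

// Keep only the logits of the last position: the row that predicts the next token.
func lastRow(_ logits: [Float], seqLen: Int, vocab: Int) -> [Float] {
    Array(logits[((seqLen - 1) * vocab)..<(seqLen * vocab)])
}

// Numerically stable softmax: subtract the max before exponentiating.
func softmax(_ logits: [Float]) -> [Float] {
    guard let maxLogit = logits.max() else { return [] }
    let exps = logits.map { exp($0 - maxLogit) }
    let sum = exps.reduce(0, +)
    return exps.map { $0 / sum }
}

// Indices of the k largest values, largest first.
func topK(_ values: [Float], k: Int) -> [Int] {
    let order = values.indices.sorted { values[$0] > values[$1] }
    return Array(order.prefix(k))
}
```

    Because softmax is monotonic, topK on the raw logits and on the probabilities returns the same indices; the probabilities only matter if you sample by weight.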
  • ¶

    Finally, we read the result back from the GPU to the CPU and return one of the top 20 tokens, chosen uniformly at random (their probabilities aren’t used here).

      var idxs = [Int32](repeating: 0, count: 20)
      finalOut[topkIndices]!.mpsndarray().readBytes(&idxs, strideBytes: nil)
      let randomIndex = Int.random(in: 0..<20)
      let selectedIndex = idxs[randomIndex]
      return selectedIndex
    }
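
    The code above picks one of the 20 indices with Int.random, so every candidate is equally likely: the 20th most likely token gets chosen as often as the first. A common alternative is to sample in proportion to the probabilities. Here is a minimal sketch, assuming you also read back the top-k probabilities (topk[0] should hold the values that go with the indices in topk[1], but check that against the MPSGraph documentation). sampleTopK is a hypothetical helper; it takes a generator parameter so it can be tested deterministically.

```swift
// Sample one token id from the top-k candidates, with probability
// proportional to each candidate's weight (weights need not sum to 1).
func sampleTopK<G: RandomNumberGenerator>(
    ids: [Int32], probs: [Float], using rng: inout G
) -> Int32 {
    precondition(!ids.isEmpty && ids.count == probs.count)
    let total = probs.reduce(0.0) { $0 + Double($1) }
    precondition(total > 0, "at least one candidate needs a positive weight")
    // Draw a point in [0, total) and walk the cumulative weights until we pass it.
    var r = Double.random(in: 0..<total, using: &rng)
    for (id, p) in zip(ids, probs) {
        if r < Double(p) { return id }
        r -= Double(p)
    }
    return ids[ids.count - 1]  // only reached through floating-point round-off
}
```

    In the inference function, ids would be idxs and probs the matching values, read back the same way with readBytes.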
    
    // LayerNorm over the last axis: y = (x - mean) / sqrt(var + eps) * gamma + beta
    func layerNorm(
      graph: MPSGraph,
      x: MPSGraphTensor,
      gamma: MPSGraphTensor,
      beta: MPSGraphTensor,
      eps: Float = 1e-5
    ) -> MPSGraphTensor {
      // MPSGraph reductions keep the reduced axis with length 1
      let mu = graph.mean(of: x, axes: [-1], name: "ln_mu")  // [S,1]
      let xc = graph.subtraction(x, mu, name: "ln_centered")  // [S,D]
    
      let sq = graph.multiplication(xc, xc, name: "ln_sq")  // [S,D]
      let varT = graph.mean(of: sq, axes: [-1], name: "ln_var")  // [S,1], biased variance
    
      // Use the eps argument instead of a hard-coded 1e-5
      let epsC = graph.constant(Double(eps), shape: [1], dataType: .float32)  // broadcasts
      let denom = graph.squareRoot(with: graph.addition(varT, epsC, name: nil), name: "den")  // [S,1]
      let norm = graph.division(xc, denom, name: "ln_norm")  // [S,D]
    
      // Learned per-feature scale (gamma) and shift (beta)
      let gB = graph.expandDims(gamma, axes: [0], name: nil)  // [1,D]
      let bB = graph.expandDims(beta, axes: [0], name: nil)  // [1,D]
      let y = graph.addition(graph.multiplication(norm, gB, name: nil), bB, name: "ln_out")  // [S,D]
    
      return y
    }
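
    To check the graph version against something simpler, here is a CPU sketch of the same LayerNorm in plain Swift, assuming a row-major [rows, d] array; layerNormCPU is a hypothetical name. Each row is shifted to mean 0 and scaled to (roughly) variance 1 before the learned gamma and beta are applied. GPT-2 uses eps = 1e-5.

```swift
// LayerNorm over the last dimension of a row-major [rows, d] matrix:
// y = (x - mean) / sqrt(variance + eps) * gamma + beta
func layerNormCPU(_ x: [Float], rows: Int, d: Int,
                  gamma: [Float], beta: [Float], eps: Float = 1e-5) -> [Float] {
    precondition(x.count == rows * d && gamma.count == d && beta.count == d)
    var y = [Float](repeating: 0, count: x.count)
    for r in 0..<rows {
        let row = x[(r * d)..<((r + 1) * d)]
        let mean = row.reduce(0, +) / Float(d)
        // Biased variance (divide by d, not d - 1), as in the graph version.
        let variance = row.reduce(0) { $0 + ($1 - mean) * ($1 - mean) } / Float(d)
        let denom = (variance + eps).squareRoot()
        for (j, v) in row.enumerated() {
            y[r * d + j] = (v - mean) / denom * gamma[j] + beta[j]
        }
    }
    return y
}
```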
    
    // GELU, tanh approximation (the variant GPT-2 uses):
    //   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    func geluTanhApprox(_ x: MPSGraphTensor, _ graph: MPSGraph) -> MPSGraphTensor {
      guard x.dataType == .float32 else {
        fatalError("Unsupported data type for GELU: \(x.dataType)")
      }
    
      let half = graph.constant(0.5, dataType: .float32)
      let one = graph.constant(1.0, dataType: .float32)
      let kA = graph.constant(0.7978845608028654, dataType: .float32)  // sqrt(2/pi)
      let kB = graph.constant(0.044715, dataType: .float32)
    
      let x3 = graph.multiplication(graph.multiplication(x, x, name: nil), x, name: "gelu_x3")
      let inner = graph.addition(x, graph.multiplication(kB, x3, name: nil), name: "gelu_inner")
      let tArg = graph.multiplication(kA, inner, name: "gelu_tanh_arg")
      let t = graph.tanh(with: tArg, name: "gelu_tanh")
      let y = graph.multiplication(
        graph.multiplication(half, x, name: nil),
        graph.addition(one, t, name: nil),
        name: "gelu_tanh_out")
      // The guard above guarantees float32 input, so no cast back is needed
      return y
    }
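
    The same approximation as a scalar function, next to the exact GELU (x times the standard normal CDF, written with erf) for comparison. This is a sketch in plain Swift using Foundation’s tanh and erf; geluTanh and geluExact are hypothetical names. The two stay within about 10⁻³ of each other for inputs between -5 and 5.

```swift
import Foundation

// GELU, tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func geluTanh(_ x: Double) -> Double {
    let k = (2.0 / Double.pi).squareRoot()
    return 0.5 * x * (1 + tanh(k * (x + 0.044715 * x * x * x)))
}

// Exact GELU: x * Φ(x), with Φ the standard normal CDF expressed through erf.
func geluExact(_ x: Double) -> Double {
    0.5 * x * (1 + erf(x / 2.0.squareRoot()))
}
```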
  • ¶

    Peek is a utility function that prints a tensor’s shape and first few values, then waits for you to press Enter. It’s very helpful when debugging the output of the model step by step.

    func peek(_ td: MPSGraphTensorData, label: String, max: Int = 8) {
      let nda = td.mpsndarray()
      let shape = (0..<nda.numberOfDimensions).map { Int(nda.length(ofDimension: $0)) }
      let n = shape.reduce(1, *)
      var v = [Float](repeating: 0, count: n)
      nda.readBytes(&v, strideBytes: nil)
      let head = v.prefix(max).map { String(format: "%.9g", Double($0)) }.joined(separator: ", ")
      print("\(label) shape=\(shape)  [\(head)\(v.count > max ? ", ..." : "")]")
      print("Press Enter to continue...")
      _ = readLine()
    }
  • ¶

    GPT-2 Tokenizer

    This code was generated by GPT-5 from the original Python implementation. I didn’t write it myself because I wanted to focus on the LLM part of this project.

    public final class GPT2Tokenizer {
  • ¶

    encoder: subword -> id, decoder: id -> subword

      private let encoder: [String: Int]
      private let decoder: [Int: String]
  • ¶

    bpeRanks: (a,b) -> rank

      private let bpeRanks: [Pair: Int]
  • ¶

    byte-level reversible mapping

      private let byteEncoder: [UInt8: String]
      private let byteDecoder: [String: UInt8]
  • ¶

    cache for BPE of a single “pretoken” (word piece before merges)

      private var cache: [String: [String]] = [:]
  • ¶

    regex used by GPT-2 for pre-tokenization

      private let tokenPattern: NSRegularExpression
  • ¶

    Pair type for merges dictionary

      private struct Pair: Hashable {
        let a: String
        let b: String
      }
    
      public init(encoder: [String: Int], merges: [String]) throws {
        self.encoder = encoder
        self.decoder = Dictionary(uniqueKeysWithValues: encoder.map { ($1, $0) })
  • ¶

    Build bpeRanks from merges lines (skip first line if it’s a version header)

        var ranks: [Pair: Int] = [:]
        var startIndex = 0
        if let first = merges.first, first.hasPrefix("#") { startIndex = 1 }
        for (i, line) in merges[startIndex...].enumerated() {
          let trimmed = line.trimmingCharacters(in: .whitespacesAndNewlines)
          guard !trimmed.isEmpty else { continue }
          let parts = trimmed.split(separator: " ")
          guard parts.count == 2 else { continue }
          let pair = Pair(a: String(parts[0]), b: String(parts[1]))
          ranks[pair] = i
        }
        self.bpeRanks = ranks
  • ¶

    Byte<->Unicode mapping (exactly like OpenAI’s bytes_to_unicode)

        let (be, bd) = GPT2Tokenizer.makeByteUnicodeMaps()
        self.byteEncoder = be
        self.byteDecoder = bd
  • ¶

    GPT-2 tokenization regex

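        // Same pattern as OpenAI's GPT-2 encoder, which is case-sensitive,
        // so no .caseInsensitive option here (otherwise "'S" would be split like "'s").
        let pattern =
          "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"
        self.tokenPattern = try NSRegularExpression(pattern: pattern, options: [])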
      }
  • ¶

    Convenience: load from URLs

      public static func load(encoderJSON urlJSON: URL, mergesTXT urlBPE: URL) throws -> GPT2Tokenizer {
        let data = try Data(contentsOf: urlJSON)
        let enc = try JSONDecoder().decode([String: Int].self, from: data)
    
        let mergesContent = try String(contentsOf: urlBPE, encoding: .utf8)
        let merges = mergesContent.split(whereSeparator: \.isNewline).map(String.init)
    
        return try GPT2Tokenizer(encoder: enc, merges: merges)
      }
  • ¶

    Encode text into GPT-2 token IDs.

      public func encode(_ text: String) -> [Int32] {
        let pretokens = findAll(pattern: tokenPattern, in: text)
        var ids: [Int] = []
        ids.reserveCapacity(pretokens.count * 2)
    
        for tok in pretokens {
  • ¶

    1) byte-encode (UTF-8 bytes → safe Unicode mapping)

          let bstr = bytesToUnicodeString(Array(tok.utf8))
  • ¶

    2) run BPE over the mapped string

          let parts = bpe(bstr)
  • ¶

    3) map subwords to ids

          for p in parts {
            if let id = encoder[p] {
              ids.append(id)
            } else {
  • ¶

    In practice this shouldn’t happen with the official files. Fallback: skip the unknown subword, or assert instead with assert(false, "Unknown BPE token \(p)").

            }
          }
        }
        return ids.map { Int32($0) }
      }
  • ¶

    Decode GPT-2 token IDs back into a String.

      public func decode(_ i32Ids: [Int32]) -> String {
  • ¶

    1) map ids to subword strings and concatenate

        let ids = i32Ids.map { Int($0) }
        let text = ids.compactMap { decoder[$0] }.joined()
  • ¶

    2) map each Unicode char back to its original byte, then UTF-8 decode

        var bytes: [UInt8] = []
        bytes.reserveCapacity(text.count)
        for ch in text {
          let s = String(ch)
          if let b = byteDecoder[s] {
            bytes.append(b)
          } else {
  • ¶

    Should not happen if maps are complete

          }
        }
        return String(decoding: bytes, as: UTF8.self)
      }
    
      private func bpe(_ token: String) -> [String] {
        if let cached = cache[token] { return cached }
    
        var word: [String] = token.map { String($0) }
        guard !word.isEmpty else { return [] }
    
        var pairs = getPairs(of: word)
        if pairs.isEmpty {
          cache[token] = [word.joined()]
          return cache[token]!
        }
    
        while true {
  • ¶

    find best (lowest rank) pair

          var minRank = Int.max
          var bigram: Pair? = nil
          for p in pairs {
            if let r = bpeRanks[p], r < minRank {
              minRank = r
              bigram = p
            }
          }
          guard let merge = bigram else { break }
  • ¶

    merge all occurrences of merge in word

          var newWord: [String] = []
          var i = 0
          while i < word.count {
            if i < word.count - 1, word[i] == merge.a, word[i + 1] == merge.b {
              newWord.append(word[i] + word[i + 1])
              i += 2
            } else {
              newWord.append(word[i])
              i += 1
            }
          }
          word = newWord
          if word.count == 1 { break }
          pairs = getPairs(of: word)
        }
    
        let result = word
        cache[token] = result
        return result
      }
    
      private func getPairs(of word: [String]) -> Set<Pair> {
        var s: Set<Pair> = []
        guard word.count >= 2 else { return s }
        for i in 0..<(word.count - 1) {
          s.insert(Pair(a: word[i], b: word[i + 1]))
        }
        return s
      }
    
      private static func makeByteUnicodeMaps() -> ([UInt8: String], [String: UInt8]) {
  • ¶

    bs = visible ASCII + Latin-1 supplement chunks (¡..¬, ®..ÿ)

        var bs: [Int] = Array(33...126) + Array(161...172) + Array(174...255)
        var cs = bs
        var n = 0
        for b in 0...255 where !bs.contains(b) {
          bs.append(b)
          cs.append(256 + n)
          n += 1
        }
        var be: [UInt8: String] = [:]
        var bd: [String: UInt8] = [:]
        for (i, b) in bs.enumerated() {
          let scalar = UnicodeScalar(cs[i])!
          let ch = String(scalar)
          be[UInt8(b)] = ch
          bd[ch] = UInt8(b)
        }
        return (be, bd)
      }
    
      private func bytesToUnicodeString(_ bytes: [UInt8]) -> String {
        var out = String()
        out.reserveCapacity(bytes.count)
        for b in bytes {
          if let ch = byteEncoder[b] { out.append(contentsOf: ch) }
        }
        return out
      }
    
      private func findAll(pattern: NSRegularExpression, in text: String) -> [String] {
        let nsrange = NSRange(text.startIndex..<text.endIndex, in: text)
        return pattern.matches(in: text, options: [], range: nsrange).compactMap { match in
          Range(match.range, in: text).map { String(text[$0]) }
        }
      }
    }
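
    To see the two main pieces of the tokenizer in isolation, here is a standalone sketch: the byte-to-Unicode table (rebuilt the same way as makeByteUnicodeMaps) and a greedy BPE loop. The merge ranks in the example are made up for illustration, not taken from GPT-2’s merges.txt, and the names byteToUnicode and bpeMerge are hypothetical. Note how the space byte becomes "Ġ" (U+0120), which is why GPT-2 vocabulary entries for words that follow a space start with that character.

```swift
// GPT-2's bytes_to_unicode: printable bytes keep their own character,
// the remaining 68 bytes are moved to code points 256, 257, ... in order.
func byteToUnicode() -> [UInt8: Character] {
    var bs: [Int] = Array(33...126) + Array(161...172) + Array(174...255)
    var cs = bs
    var n = 0
    for b in 0...255 where !bs.contains(b) {
        bs.append(b)
        cs.append(256 + n)
        n += 1
    }
    var table: [UInt8: Character] = [:]
    for (b, c) in zip(bs, cs) {
        table[UInt8(b)] = Character(Unicode.Scalar(UInt32(c))!)
    }
    return table
}

// Greedy BPE: repeatedly merge every occurrence of the adjacent pair with the
// lowest rank. Ranks are keyed "a b", like the lines of merges.txt.
func bpeMerge(_ symbols: [String], ranks: [String: Int]) -> [String] {
    var word = symbols
    while word.count > 1 {
        var best: (a: String, b: String, rank: Int)? = nil
        for i in 0..<(word.count - 1) {
            if let r = ranks[word[i] + " " + word[i + 1]], r < (best?.rank ?? Int.max) {
                best = (a: word[i], b: word[i + 1], rank: r)
            }
        }
        guard let pair = best else { break }  // no known merge left
        var merged: [String] = []
        var i = 0
        while i < word.count {
            if i < word.count - 1 && word[i] == pair.a && word[i + 1] == pair.b {
                merged.append(pair.a + pair.b)
                i += 2
            } else {
                merged.append(word[i])
                i += 1
            }
        }
        word = merged
    }
    return word
}

let table = byteToUnicode()
let symbols = " low".utf8.map { String(table[$0]!) }  // ["Ġ", "l", "o", "w"]
let toyRanks = ["Ġ l": 0, "Ġl o": 1, "Ġlo w": 2]     // hypothetical ranks
print(bpeMerge(symbols, ranks: toyRanks))             // ["Ġlow"]
```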
  • ¶

    Alright, this is it! We’ve implemented a GPT-2 inference engine from scratch in Swift using only Apple’s MPSGraph framework. It can load the open-source weights and run inference on a given prompt.

    I’ve taken some shortcuts to keep the code simple (memory management, for example, would definitely need more work), but I think the gist of it is all here.

    I hope you found this useful and educational! If you have any questions or suggestions, please reach out to me at hello AT khamidou.com