let's build our own AI model

all code in this blog post can be found on my gitlab

what you need:

  • computer (this example was made on a MacBook)
  • python 3.10
  • keyboard
  • internet
  • a lot of pip modules

to get started, make a folder to put your code in, then create a python virtual environment inside it by running python3 -m venv venv and activate it with source venv/bin/activate

now, like i said, you’ll need quite a few pip modules.

to save you a big headache, you can run a script that installs them all (make sure your virtual environment is activated, so you don’t install a lot of garbage into your main environment).

but first, check what’s inside it (never execute a script without reading it):

curl -s https://gitlab.giotje.dev/bunny-sh/gist/-/raw/main/python/install_pip_packages_for_tinygrad


all this does is download the requirements.txt and then install it

then, if you’re okay with installing it, you can run this:

curl -s https://gitlab.giotje.dev/bunny-sh/gist/-/raw/main/python/install_pip_packages_for_tinygrad | bash

the pipe (|) sends the content of the file to bash, which then executes it.

you’ll see a lot of text, and then when it’s installed you’re ready.

now go to your code editor of choice and create a file called main.py


first we will start by importing all of our required libraries

#!/usr/bin/python3
from typing import Optional, Union
import argparse
from tqdm import trange
import numpy as np
import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.shape.symbolic import Variable

what these do will become clear as we go on

now let’s define the context for our model. the context is the maximum number of tokens (the prompt plus everything generated so far) the model can look at when it predicts the next one.

normal computers do not remember context: when i tell my computer i’m in Germany, and then in a separate question (instruction) ask it what the weather is, it will return the weather of my current location (most likely). it has long forgotten that i was even talking about Germany.

this is what separates normal computers from Artificial Intelligence: they keep context in mind. (in this script the context is one prompt plus its answer; nothing carries over between runs.)

Remember that the more context you allow, the more memory (RAM or VRAM) your computer needs, because every attention layer keeps a key/value cache sized for MAX_CONTEXT tokens.

MAX_CONTEXT = getenv('MAX_CONTEXT', 128)
HALF = getenv('HALF')

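here we check if MAX_CONTEXT is already defined in an environment variable; if not, we set it to 128. HALF works the same way: it reads the HALF environment variable (off if it isn’t set), and when it’s set to 1 the model runs in half precision (fp16), which uses less memory.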
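to put a number on that memory cost: each Attention layer allocates its cache_kv tensor with shape (2, bsz, MAX_CONTEXT, n_heads, head_dim), and n_heads * head_dim is just dim. the sketch below multiplies that out over all layers. kv_cache_bytes is a made-up helper for illustration, and it assumes 4 bytes per value for float32 (2 when HALF is set); it only counts the cache, not the weights or activations.

```python
def kv_cache_bytes(n_layers: int, dim: int, max_context: int,
                   batch_size: int = 1, bytes_per_value: int = 4) -> int:
    """Estimate the memory used by the key/value caches of all layers.

    Hypothetical helper: mirrors the cache_kv shape used in Attention,
    (2, bsz, MAX_CONTEXT, n_heads, head_dim), where n_heads * head_dim == dim.
    """
    # the leading 2 is one slot for keys and one for values
    per_layer = 2 * batch_size * max_context * dim * bytes_per_value
    return n_layers * per_layer


if __name__ == "__main__":
    # gpt2-medium has 24 layers with dim 1024
    for ctx in (128, 1024):
        mib = kv_cache_bytes(24, 1024, ctx) / 2**20
        print(f"MAX_CONTEXT={ctx}: {mib:.0f} MiB (float32)")
```

under these assumptions gpt2-medium needs 24 MiB of cache at the default MAX_CONTEXT of 128 and 192 MiB at 1024; the cost grows linearly with the context length and the batch size.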

now let’s write our first class called Attention

class Attention:
  def __init__(self, dim, n_heads):
    self.c_attn = Linear(dim, 3*dim, bias=True)
    self.c_proj = Linear(dim, dim, bias=True)
    self.n_heads = n_heads
    self.dim = dim
    self.head_dim = dim // n_heads

  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
    if mask is not None or start_pos.val == 0:
      # no symbolic shape qkv when consuming prompts
      start_pos = start_pos.val

    if HALF: x = x.half()
    xqkv = self.c_attn(x)
    xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(None, None, self.n_heads, self.head_dim) for i in range(3)]
    bsz, seqlen, _, _ = xq.shape

    # create kv cache
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype)

    if start_pos > 0:
      keys = self.cache_kv[0].shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
      values = self.cache_kv[1].shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
    else:
      keys = xk
      values = xv

    # update the cache
    new_cache = Tensor.stack([keys, values]).pad((None, None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()
    self.cache_kv.assign(new_cache).realize()

    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, self.dim))

What Is Attention?

in AI, “attention” is a mechanism that helps models decide which parts of the input are more important at any given time. imagine you’re reading a book: you don’t focus on every word equally, you naturally pay more attention to the important words and phrases. attention does the same thing for a model: for every token, it scores how relevant each token it is allowed to see is, and builds a weighted mix of their information. this is crucial for tasks like understanding long sentences or complex images.

this is a complete implementation of an attention mechanism, which allows the model to selectively focus on certain parts of the input data. here’s a quick breakdown:

• the __init__ method sets up the linear transformations that will be applied to the input. the first one (c_attn) projects each token vector to three times its size, and that output is split into three parts: queries (q), keys (k), and values (v). the second one (c_proj) projects the combined attention output back to the model dimension.

• in the __call__ method, the output of c_attn is split into q, k, and v, and each one is reshaped into n_heads smaller pieces of size head_dim (dim // n_heads). every head computes its own attention, which lets the model look at different relationships in the input simultaneously.

• the attention class also keeps a cache of previously computed keys and values (the kv cache). during generation, each step only computes q, k, and v for the newest token and reuses the cached keys and values of all earlier tokens instead of recomputing them.

• finally, after applying attention using scaled_dot_product_attention, the heads are merged back together and passed through the final linear layer (c_proj) before being returned.
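to make the math concrete, here is a minimal NumPy sketch of scaled dot-product attention for a single batch entry: softmax(q·kᵀ / sqrt(head_dim) + mask)·v. it is written from the standard formula, not copied from tinygrad’s own scaled_dot_product_attention, and it assumes the (heads, seq, head_dim) layout that Attention produces after its transpose(1, 2), minus the batch dimension.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # subtract the max for numerical stability before exponentiating
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """q: (heads, q_len, head_dim); k and v: (heads, kv_len, head_dim)."""
    head_dim = q.shape[-1]
    # similarity of every query with every key, scaled by sqrt(head_dim)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    if mask is not None:
        scores = scores + mask  # -inf entries get zero weight after softmax
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v  # weighted average of the values


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    heads, seqlen, head_dim = 2, 4, 8
    q, k, v = (rng.standard_normal((heads, seqlen, head_dim)) for _ in range(3))
    # causal mask: token i may only look at tokens 0..i
    mask = np.triu(np.full((seqlen, seqlen), -np.inf), k=1)
    out = scaled_dot_product_attention(q, k, v, mask)
    print(out.shape)  # (2, 4, 8)
    # the first token can only attend to itself, so its output is its own value
    print(np.allclose(out[:, 0], v[:, 0]))  # True
```

each head runs the same computation independently on its own slice of q, k and v, which is why splitting dim into n_heads pieces costs nothing extra.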

now let’s create our second class called FeedForward

class FeedForward:
  def __init__(self, dim, hidden_dim):
    self.c_fc = Linear(dim, hidden_dim, bias=True)
    self.c_proj = Linear(hidden_dim, dim, bias=True)

  def __call__(self, x:Tensor) -> Tensor:
    return self.c_proj(self.c_fc(x).gelu())

what’s this?

FeedForward is a type of layer that helps the model learn more complex patterns by transforming the input data in two steps. here’s how it works:

• the __init__ method sets up two linear transformations:

• c_fc transforms the input from dim size to hidden_dim.

• c_proj brings it back from hidden_dim to dim.

• in the __call__ method, the input data (x) is first passed through c_fc, which expands it into a higher-dimensional space. then, it applies a non-linear activation function called gelu() (gaussian error linear unit), which helps the model handle more complex patterns. finally, it passes the result through c_proj to bring the data back to its original dimensionality.
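the two steps above can be sketched in NumPy like this. the weights are plain arrays laid out for x @ W (an assumption of this sketch, not how tinygrad’s Linear stores them), and gelu uses the tanh approximation, which is the variant the original GPT-2 uses.

```python
import numpy as np


def gelu(x: np.ndarray) -> np.ndarray:
    # tanh approximation of the gaussian error linear unit
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def feed_forward(x, w_fc, b_fc, w_proj, b_proj):
    """x: (seq, dim); w_fc: (dim, hidden_dim); w_proj: (hidden_dim, dim)."""
    h = gelu(x @ w_fc + b_fc)   # expand: dim -> hidden_dim (4 * dim in TransformerBlock)
    return h @ w_proj + b_proj  # project back: hidden_dim -> dim


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dim = 8
    hidden = 4 * dim
    x = rng.standard_normal((3, dim))
    out = feed_forward(
        x,
        rng.standard_normal((dim, hidden)), np.zeros(hidden),
        rng.standard_normal((hidden, dim)), np.zeros(dim),
    )
    print(out.shape)  # (3, 8): same shape as the input
```

because the output has the same shape as the input, the result can be added straight back onto the token vectors, which is exactly what TransformerBlock does.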

this layer is often used in combination with the Attention layer we discussed earlier, as part of a larger transformer block. attention lets tokens exchange information with each other, while the feedforward layer then processes each token’s vector on its own, giving the model the extra capacity it needs to learn complex patterns.

talking about transformers, let’s write our TransformerBlock and Transformer:

class TransformerBlock:
  def __init__(self, dim, n_heads, norm_eps):
    self.attn = Attention(dim, n_heads)
    self.mlp = FeedForward(dim, 4*dim)
    self.ln_1 = LayerNorm(dim, norm_eps)
    self.ln_2 = LayerNorm(dim, norm_eps)

  def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
    h = x + self.attn(self.ln_1(x), start_pos, mask).float()
    return (h + self.mlp(self.ln_2(h)))

class Transformer:
  def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
    self.vocab_size = vocab_size
    self.wte = Embedding(vocab_size, dim)
    self.wpe = Embedding(max_seq_len, dim)
    self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
    self.ln_f = LayerNorm(dim, norm_eps)
    self.lm_head = Linear(dim, vocab_size, bias=False)
    self.forward_jit = TinyJit(self.forward)

  def forward(self, tokens:Union[Tensor,Variable], start_pos:Variable, temperature:float=0.0):
    if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
    if isinstance(tokens, Variable):
      seqlen = 1
      tok_emb = self.wte.weight.shrink(((tokens, tokens+1), None))
    else:
      seqlen = tokens.shape[1]
      tok_emb = self.wte(tokens)

    pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
    h = tok_emb + pos_emb

    if HALF: h = h.half()

    mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos.val+1) if seqlen > 1 else None

    for hi in self.h: h = hi(h, start_pos, mask)

    logits = self.lm_head(self.ln_f(h))

    if logits.shape[1] == 0:
      # special case for empty prompt
      logits = Tensor.ones((logits.shape[0], self.vocab_size), dtype=logits.dtype, device=logits.device)
    else:
      logits = logits[:, -1, :]

    if temperature < 1e-6:
      ret = logits.argmax(-1)
    else:
      ret = (logits / temperature).softmax().multinomial()
    return ret.flatten().realize()

  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
    forward = (self.forward_jit if (isinstance(tokens, Variable) or tokens.shape[1] == 1) and getenv("JIT") else self.forward)
    return forward(tokens, start_pos, temperature)

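what’s going on here?

the Transformer class is the main structure that ties everything together. it contains a stack of TransformerBlock layers and handles the embedding of input tokens into a vector space that the model can work with. each TransformerBlock normalizes its input (ln_1), runs attention and adds the result back to its input (a residual connection), then does the same with ln_2 and the feedforward layer (mlp). here’s a breakdown:

• in the __init__ method, several components are set up: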

• self.wte: this is the token embedding layer, which converts each token (a word or a piece of a word) into a vector of size dim.

• self.wpe: this is the position embedding layer. it learns one vector per position (up to max_seq_len = 1024), which tells the model the order of the tokens.

• self.h: this is a list of TransformerBlock layers. the model processes data through multiple transformer blocks to refine its understanding.

• self.ln_f: this is a final layer normalization applied before output.

• self.lm_head: this linear layer generates the final output (logits): one score for every token in the vocabulary, representing the model’s prediction for the next token.

• self.forward_jit: this wraps the forward method with a just-in-time (jit) compiler for faster execution.

• the forward method is where the actual magic happens:

  1. input tokens are passed through the embedding layers (wte for token embeddings and wpe for position embeddings). this turns them into vectors the model can process.
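  2. if the input is a single token (like during generation with a batch size of 1), it arrives as a symbolic Variable instead of a Tensor, and its embedding is looked up by slicing (shrink) the matching row out of the embedding matrix, so the same compiled code works for any token value.
  3. when more than one token is processed at once, the model applies a mask to ensure that the prediction for each token only depends on previous tokens, not future ones.
  4. the input is passed through all the transformer blocks (self.h), each one refining the representation of the data.
  5. the final result is normalized and passed through the lm_head layer, which produces the logits (the output predictions). only the logits of the last position are kept, since that’s where the next token is predicted.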
  6. depending on the temperature value, either the most likely prediction is selected (if temperature is near 0) or predictions are sampled based on probability.
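the mask from step 3 can be reproduced in NumPy to see what it looks like. this sketch assumes tinygrad’s triu(k) keeps the same entries as np.triu (everything with column - row >= k), and drops the two leading size-1 dimensions of the real mask.

```python
import numpy as np


def causal_mask(seqlen: int, start_pos: int) -> np.ndarray:
    """NumPy version of the mask built in Transformer.forward.

    Rows are the new tokens, columns are every position seen so far
    (start_pos cached tokens + seqlen new ones). Entry (i, j) is -inf when
    position j lies in the future of new token i, and 0 otherwise.
    """
    full = np.full((seqlen, start_pos + seqlen), -np.inf)
    return np.triu(full, k=start_pos + 1)


if __name__ == "__main__":
    # 0 = may attend, -inf = masked out
    print(causal_mask(3, 0))  # a fresh 3-token prompt
    print(causal_mask(2, 2))  # 2 new tokens after 2 cached ones
```

adding -inf to a score before the softmax turns its weight into exactly 0, so masked positions contribute nothing. new tokens can always see every cached token, which is why the diagonal is shifted by start_pos.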

the __call__ method wraps the forward function: it only uses the jit-compiled forward_jit when a single token is being processed and the JIT environment variable is set (for example JIT=1 python3 ./main.py); otherwise it calls the regular forward.
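TransformerBlock itself is only three lines, but the order matters: layer norm runs before each sublayer, and each sublayer’s output is added to its input instead of replacing it. here is a NumPy sketch of that data flow; the attention and feedforward layers are swapped for hypothetical stand-in functions so the block structure is the only thing on display.

```python
import numpy as np


def layer_norm(x, weight, bias, eps=1e-5):
    # normalize each token vector to mean 0 and variance 1, then scale and shift
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias


def transformer_block(x, attn, mlp, ln1, ln2):
    """Same data flow as TransformerBlock.__call__ (pre-norm + residuals).

    attn and mlp are any callables mapping (seq, dim) -> (seq, dim);
    ln1 and ln2 are (weight, bias) pairs for the two layer norms.
    """
    h = x + attn(layer_norm(x, *ln1))    # residual around attention
    return h + mlp(layer_norm(h, *ln2))  # residual around the feedforward layer


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dim = 8
    x = rng.standard_normal((4, dim))
    ones, zeros = np.ones(dim), np.zeros(dim)
    # stand-in sublayers: a small random linear map instead of real attention/MLP
    w = rng.standard_normal((dim, dim)) * 0.1
    out = transformer_block(x, lambda t: t @ w, lambda t: t @ w, (ones, zeros), (ones, zeros))
    print(out.shape)  # (4, 8)
```

a nice property of the residual design: if a sublayer outputs zeros, the block passes its input through unchanged, so stacking many blocks doesn’t wash out the original token information.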

now let’s continue:

VOCAB_SIZE = 50257
MODEL_PARAMS = {
  'gpt2':         dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE),   # 124M params
  'gpt2-medium':  dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE),  # 350M params
  'gpt2-large':   dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE),  # 774M params
  'gpt2-xl':      dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE),  # 1558M params
}

okay… maybe i lied… we aren’t really building our own model from scratch. obviously i do not have the compute, time and will to spend hundreds of millions of dollars, years of my life and ZETTABYTES of scraped data to train one.

and thankfully we don’t have to, because OpenAI already did. in this example we load the weights of OpenAI’s open-source GPT-2 model into the model we just wrote.

now about the code:

the variable VOCAB_SIZE is simply the size of the vocabulary the model will be working with. in this case, the model is designed to understand 50,257 unique tokens (words or subwords). this is important because the tokenizer turns every piece of text into token IDs between 0 and 50,256, and the embedding layer (wte) and lm_head need one row per possible ID.

next, MODEL_PARAMS is a dictionary that stores the configuration for different sizes of gpt-2 models. each model size has its own set of parameters that define how big and complex the model is. here’s a breakdown of the key parameters:

• n_layers: the number of transformer blocks in the model. more layers allow the model to learn deeper and more complex patterns.

• n_heads: the number of attention heads in each block. having more heads means the model can focus on different parts of the input simultaneously.

• dim: the size of the hidden layers (the width of each block). larger dimensions allow the model to handle more information.

• norm_eps: a small value used in layer normalization to prevent division by zero.

• vocab_size: this matches the VOCAB_SIZE we defined earlier, telling the model how many different tokens it needs to handle.

each version of the model is larger and more powerful than the previous one:

• gpt2 has 12 layers and 12 heads, making it the smallest model with around 124 million parameters.

• gpt2-medium is larger, with 24 layers and 16 heads, totaling about 350 million parameters.

• gpt2-large grows to 36 layers and 20 heads, with 774 million parameters.

• gpt2-xl is the biggest, with 48 layers and 25 heads, reaching a whopping 1.5 billion parameters.

in short, this dictionary defines the architecture for each version of the gpt-2 model, allowing you to easily switch between different sizes based on your needs and computing power.
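you can check the parameter counts in the code comments by adding up the layers the Transformer class creates. gpt2_param_count is a hypothetical helper written for this sketch; it counts wte, wpe, every TransformerBlock and ln_f, and skips lm_head because its weight is shared with wte.

```python
def gpt2_param_count(n_layers: int, dim: int, vocab_size: int = 50257,
                     max_seq_len: int = 1024) -> int:
    """Count the parameters of the Transformer class for one MODEL_PARAMS entry.

    lm_head is not counted separately because its weight is tied to wte.
    """
    embeddings = vocab_size * dim + max_seq_len * dim  # wte + wpe
    per_layer = (
        2 * dim                       # ln_1 weight + bias
        + dim * 3 * dim + 3 * dim     # attn.c_attn
        + dim * dim + dim             # attn.c_proj
        + 2 * dim                     # ln_2 weight + bias
        + dim * 4 * dim + 4 * dim     # mlp.c_fc
        + 4 * dim * dim + dim         # mlp.c_proj
    )
    final_norm = 2 * dim                               # ln_f
    return embeddings + n_layers * per_layer + final_norm


if __name__ == "__main__":
    sizes = {"gpt2": (12, 768), "gpt2-medium": (24, 1024),
             "gpt2-large": (36, 1280), "gpt2-xl": (48, 1600)}
    for name, (n_layers, dim) in sizes.items():
        print(f"{name}: {gpt2_param_count(n_layers, dim) / 1e6:.0f}M")
```

this gives 124,439,808 parameters for gpt2, and roughly 355M, 774M and 1558M for the others (the code comment rounds gpt2-medium down to 350M). almost all of it sits in the 12 * dim² weights per layer, which is why dim and n_layers drive the model size.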

now let’s define our GPT2 class

class GPT2:
  @staticmethod
  def build(model_size="gpt2"):
    tokenizer = tiktoken.get_encoding("gpt2")

    model = Transformer(**MODEL_PARAMS[model_size])
    weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
    # special treatment for the Conv1D weights we need to transpose
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    for k in weights:
      if k.endswith(transposed):
        weights[k] = weights[k].T
    # lm head and wte are tied
    weights['lm_head.weight'] = weights['wte.weight']

    load_state_dict(model, weights)

    if HALF:
      for l in get_state_dict(model).values():
        l.assign(l.half().realize())

    return GPT2(model, tokenizer)

  def __init__(self, model, tokenizer):
    self.model = model
    self.tokenizer = tokenizer

  def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
    prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
    toks = [prompt_tokens[:] for _ in range(batch_size)]
    start_pos = 0
    for _ in trange(max_length, disable=(timing==True)):
      GlobalCounters.reset()
      if timing: print("")
      st = GlobalCounters.time_sum_s
      with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
                  f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                  (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
        if batch_size == 1 and len(toks[0][start_pos:]) == 1:
          tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
        else:
          tokens = Tensor([x[start_pos:] for x in toks])
        tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature).numpy().tolist()
      start_pos = len(toks[0])
      for i,t in enumerate(tok): toks[i].append(t)
    return [self.tokenizer.decode(x) for x in toks]

build method

this is a static method that builds the gpt-2 model. it does a few important things:

• first, it retrieves the gpt-2 tokenizer using tiktoken. this tokenizer is responsible for converting text into a format the model can process.

• next, it creates a transformer model using the parameters for the specified model size (like gpt2, gpt2-medium, etc.) from MODEL_PARAMS.

• it then downloads pre-trained weights for the model from Hugging Face (the pytorch_model.bin file), which contains all the learned parameters for the model.

• certain weights (attn.c_attn.weight, attn.c_proj.weight, mlp.c_fc.weight and mlp.c_proj.weight) need to be transposed: as the code comment says, GPT-2 stores these layers as Conv1D, which keeps its weight matrix the other way around compared to tinygrad’s Linear, so the code flips them with .T.

• GPT-2 uses the same matrix for the token embedding (wte.weight) and the output layer (lm_head.weight); this is called weight tying. the code therefore reuses wte.weight as lm_head.weight instead of loading a separate matrix.

• if the HALF flag is set, all the model’s weights are converted to half precision (fp16), which reduces memory usage and can speed up computations on certain hardware.

• finally, the method returns an instance of the GPT2 class, containing both the model and tokenizer.
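here is a small NumPy check of why that transpose is needed. it assumes, as the code comment says, that these GPT-2 layers are Conv1D in the checkpoint (weight shaped (in, out), computing x @ W + b), while a Linear layer keeps its weight as (out, in) and computes x @ W.T + b. the sizes are arbitrary toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 4, 6
x = rng.standard_normal((2, in_features))
bias = rng.standard_normal(out_features)

# Conv1D-style layer: weight shape (in, out), y = x @ W + b
w_conv1d = rng.standard_normal((in_features, out_features))
y_conv1d = x @ w_conv1d + bias

# Linear-style layer: weight shape (out, in), y = x @ W.T + b,
# so the checkpoint weight is transposed once when it is loaded
w_linear = w_conv1d.T
y_linear = x @ w_linear.T + bias

print(w_linear.shape)                   # (6, 4)
print(np.allclose(y_conv1d, y_linear))  # True
```

transposing once at load time is cheaper than transposing on every forward pass, and it lets the model code use a single layer type throughout.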

__init__ method

this is the constructor for the GPT2 class. it simply stores the transformer model and tokenizer for later use.

generate method

this is where the magic happens. the generate method allows the model to generate text based on a given prompt. here’s how it works:

• the prompt is tokenized using the tokenizer. this converts the input text into a list of tokens that the model can understand.

• a list of token sequences (toks) is created, one per batch entry. each sequence starts with the same prompt tokens.

• the method then enters a loop that generates new tokens one by one, max_length times. it uses the trange function to show a progress bar.

• for each iteration, the model processes the input tokens:

• only the tokens the model hasn’t seen yet are passed in: the whole prompt on the first step, then just the newest token. start_pos tells the model how many earlier tokens are already in its kv cache.

• if the batch size is 1 and only one new token needs processing, it is passed as a symbolic Variable bound to the token id. otherwise, the tokens are passed as a Tensor.

• the model predicts the next token based on the input, the temperature, and the starting position, and that token is appended to each sequence.

• the process repeats until max_length new tokens have been generated.

• after generating all tokens, the method decodes them back into text using the tokenizer and returns the generated text.

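the timing option is useful for debugging or performance tracking. it prints how long each iteration of model inference takes; with the DEBUG environment variable set, it also prints detailed performance metrics (GOPS and memory, plus GPU time and bandwidth at DEBUG=2 or higher).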
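stripped of batching, timing and tinygrad, the loop in generate boils down to the sketch below. next_token and toy_model are hypothetical stand-ins for self.model; the point is to show which tokens get fed in on each step and how start_pos moves along.

```python
from typing import Callable, List


def generate(next_token: Callable[[List[int], int], int],
             prompt_tokens: List[int], max_length: int) -> List[int]:
    """Simplified single-sequence version of GPT2.generate.

    next_token(new_tokens, start_pos) stands in for self.model: it receives only
    the tokens it has not seen yet, plus how many tokens are already cached.
    """
    toks = prompt_tokens[:]
    start_pos = 0
    for _ in range(max_length):
        new_tokens = toks[start_pos:]  # whole prompt first, then 1 token per step
        tok = next_token(new_tokens, start_pos)
        start_pos = len(toks)          # everything so far is now in the kv cache
        toks.append(tok)
    return toks


if __name__ == "__main__":
    calls = []

    def toy_model(new_tokens: List[int], start_pos: int) -> int:
        # hypothetical stand-in model: records its input, "predicts" last token + 1
        calls.append((list(new_tokens), start_pos))
        return new_tokens[-1] + 1

    print(generate(toy_model, [10, 20, 30], max_length=3))  # [10, 20, 30, 31, 32, 33]
    print(calls)  # [([10, 20, 30], 0), ([31], 3), ([32], 4)]
```

only the first call sees the full prompt; every later call gets a single token, which is what makes the kv cache (and the single-token jit path) pay off.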

now, let’s write our main code

if __name__ == "__main__":
  Tensor.no_grad = True
  print(f"using {Device.DEFAULT} backend")
  default_prompt = "What is the answer to life, the universe, and everything?"

  parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to start with")
  parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
  parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
  parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
  parser.add_argument('--timing', action='store_true', help="Print timing per token")
  parser.add_argument('--seed', type=int, help="Set the random seed")
  parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
  parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
  parser.add_argument('--noshow', action='store_true', help="Don't show the output")
  args = parser.parse_args()

  if args.seed is not None:
    Tensor.manual_seed(args.seed)
    np.random.seed(args.seed)

  print(f"using {args.model_size}")
  gpt2 = GPT2.build(args.model_size)

  if args.benchmark != -1:
    gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
  else:
    texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
    if not args.noshow:
      print('Generating text...')
      if len(texts) == 1: print(texts[0])
      else:
        for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)

    # validate output!
    if args.temperature == 0 and args.model_size == "gpt2-medium" and args.count == 10:
      expected = {
        default_prompt: "What is the answer to life, the universe, and everything?\n\nThe answer is that we are all one",
        "Hello.": "Hello. I'm a little late to the party, but",
      }
      try:
        assert texts[0] == expected[args.prompt]
        print(colored("output validated", "green"))
      except KeyError:
        pass

this block of code is where the entire gpt-2 setup comes together. it’s the entry point of the script and handles user input, model setup, and text generation. let’s break it down:

setup

• Tensor.no_grad = True: this tells tinygrad not to track gradients, which are only needed for training, not for inference (when we’re just generating text). this saves memory and work.

• print(f"using {Device.DEFAULT} backend"): this prints which backend (cpu, gpu, etc.) is being used to run the model.

• default_prompt: a string that will be used as the default starting prompt for text generation if the user doesn’t provide one.

argument parsing

the script uses argparse to allow you to pass in various arguments when running it. here are the arguments you can pass:

• --prompt: the initial text prompt the model starts generating from. the default is “What is the answer to life, the universe, and everything?”.

• --count: the number of tokens to generate (default is 100).

• --temperature: this controls the randomness of the model’s output (default is 0.8). lower temperatures give more deterministic results, 0 always picks the most likely token, and higher values add variety.

• --model_size: allows you to choose which version of the gpt-2 model to use (e.g., gpt2, gpt2-medium, etc.). default is gpt2-medium.

• --timing: if enabled, this prints out how long each token takes to generate (useful for performance tracking).

• --seed: if set, this makes sampled output repeatable by fixing the random number generator’s seed.

• --batch_size: this controls how many sequences the model processes at once.

• --benchmark: if set, the model runs a single forward pass over a random input of the given number of tokens instead of generating text.

• --noshow: if enabled, the script will generate text but not print it out.
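to see what --temperature actually does, here is the selection step at the end of Transformer.forward redone in NumPy for one sequence. tinygrad uses its own softmax().multinomial(); this sketch substitutes NumPy’s Generator.choice for the sampling, and sample_next_token is a made-up name.

```python
import numpy as np


def sample_next_token(logits: np.ndarray, temperature: float,
                      rng: np.random.Generator) -> int:
    """Pick the next token the way Transformer.forward does, for one sequence."""
    if temperature < 1e-6:
        return int(np.argmax(logits))  # greedy: always the highest-scoring token
    scaled = logits / temperature      # < 1 sharpens the distribution, > 1 flattens it
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()               # softmax
    return int(rng.choice(len(logits), p=probs))


if __name__ == "__main__":
    logits = np.array([2.0, 1.0, 0.1])
    rng = np.random.default_rng(0)  # a fixed seed makes samples repeatable, like --seed
    print(sample_next_token(logits, 0.0, rng))  # 0
    print([sample_next_token(logits, 0.8, rng) for _ in range(10)])
```

at temperature 0 the output never changes between runs; above 0 it depends on the random generator, which is why --seed only matters when sampling.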

setting the seed

if the user provides a seed, it sets the seed for both the Tensor and numpy’s random number generators. this ensures that the generated text is reproducible.

building the model

the script builds the gpt-2 model based on the chosen model size (default is gpt2-medium) by calling GPT2.build().

running the model

• if --benchmark is set, the script runs a performance benchmark on random input tokens. this tests how fast the model can process tokens but doesn’t generate meaningful text.

• otherwise, the script calls the generate method of the gpt-2 model, which generates text based on the provided prompt, token count, temperature, and batch size.

• if --noshow isn’t set, the script prints the generated text to the console. if there are multiple responses (from a larger batch size), it prints each one with a label, starting at “Response 0:”.

output validation

there’s a special check to validate the output for a specific combination of parameters:

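• if the temperature is 0, the model size is gpt2-medium, and the token count is 10, the script compares the output with a known expected result. expected results exist for two prompts, default_prompt and “Hello.”; for any other prompt the check is skipped.

• if the output matches the expected text, it prints “output validated” in green. if it doesn’t, the assert fails and the script stops with an AssertionError.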

now is the time to run this puppy. execute the following command: python3 ./main.py

this can take some time: first it downloads the model weights from Hugging Face, then it loads them into RAM.

Once that’s done you will be greeted by a message generated by your very own AI!


you can also input your very own prompt with the --prompt argument.


now remember the arguments we added?

you can let the AI write longer answers (giving it more room to get to an answer) with --count, use the bigger GPT-2 models with --model_size (they need a lot more RAM), make the output more or less random with --temperature, the sky is the limit!

complete code:

#!/usr/bin/env python3
from typing import Optional, Union
import argparse
from tqdm import trange
import numpy as np
import tiktoken
from tinygrad import Tensor, TinyJit, Device, GlobalCounters
from tinygrad.helpers import Timing, DEBUG, getenv, fetch, colored
from tinygrad.nn import Embedding, Linear, LayerNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.shape.symbolic import Variable

MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")


class Attention:
    def __init__(self, dim, n_heads):
        self.c_attn = Linear(dim, 3 * dim, bias=True)
        self.c_proj = Linear(dim, dim, bias=True)
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads

    def __call__(
        self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]
    ) -> Tensor:
        if mask is not None or start_pos.val == 0:
            # no symbolic shape qkv when consuming prompts
            start_pos = start_pos.val

        if HALF:
            x = x.half()
        xqkv = self.c_attn(x)
        xq, xk, xv = [
            xqkv.shrink((None, None, (i * self.dim, (i + 1) * self.dim))).reshape(
                None, None, self.n_heads, self.head_dim
            )
            for i in range(3)
        ]
        bsz, seqlen, _, _ = xq.shape

        # create kv cache
        if not hasattr(self, "cache_kv"):
            self.cache_kv = Tensor.zeros(
                2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype
            )

        if start_pos > 0:
            keys = (
                self.cache_kv[0]
                .shrink((None, (0, start_pos), None, None))
                .cat(xk, dim=1)
            )
            values = (
                self.cache_kv[1]
                .shrink((None, (0, start_pos), None, None))
                .cat(xv, dim=1)
            )
        else:
            keys = xk
            values = xv

        # update the cache
        new_cache = (
            Tensor.stack([keys, values])
            .pad((None, None, (0, MAX_CONTEXT - start_pos - seqlen), None, None))
            .contiguous()
        )
        self.cache_kv.assign(new_cache).realize()

        xq, keys, values = (
            xq.transpose(1, 2),
            keys.transpose(1, 2),
            values.transpose(1, 2),
        )
        return self.c_proj(
            xq.scaled_dot_product_attention(keys, values, mask)
            .transpose(1, 2)
            .reshape(bsz, seqlen, self.dim)
        )


class FeedForward:
    def __init__(self, dim, hidden_dim):
        self.c_fc = Linear(dim, hidden_dim, bias=True)
        self.c_proj = Linear(hidden_dim, dim, bias=True)

    def __call__(self, x: Tensor) -> Tensor:
        return self.c_proj(self.c_fc(x).gelu())


class TransformerBlock:
    def __init__(self, dim, n_heads, norm_eps):
        self.attn = Attention(dim, n_heads)
        self.mlp = FeedForward(dim, 4 * dim)
        self.ln_1 = LayerNorm(dim, norm_eps)
        self.ln_2 = LayerNorm(dim, norm_eps)

    def __call__(self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]):
        h = x + self.attn(self.ln_1(x), start_pos, mask).float()
        return h + self.mlp(self.ln_2(h))


class Transformer:
    def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
        self.vocab_size = vocab_size
        self.wte = Embedding(vocab_size, dim)
        self.wpe = Embedding(max_seq_len, dim)
        self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
        self.ln_f = LayerNorm(dim, norm_eps)
        self.lm_head = Linear(dim, vocab_size, bias=False)
        self.forward_jit = TinyJit(self.forward)

    def forward(
        self,
        tokens: Union[Tensor, Variable],
        start_pos: Variable,
        temperature: float = 0.0,
    ):
        if not hasattr(self, "allpos"):
            self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
        if isinstance(tokens, Variable):
            seqlen = 1
            tok_emb = self.wte.weight.shrink(((tokens, tokens + 1), None))
        else:
            seqlen = tokens.shape[1]
            tok_emb = self.wte(tokens)

        pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos + seqlen))))
        h = tok_emb + pos_emb

        if HALF:
            h = h.half()

        mask = (
            Tensor.full(
                (1, 1, seqlen, start_pos.val + seqlen), float("-inf"), dtype=h.dtype
            ).triu(start_pos.val + 1)
            if seqlen > 1
            else None
        )

        for hi in self.h:
            h = hi(h, start_pos, mask)

        logits = self.lm_head(self.ln_f(h))

        if logits.shape[1] == 0:
            # special case for empty prompt
            logits = Tensor.ones(
                (logits.shape[0], self.vocab_size),
                dtype=logits.dtype,
                device=logits.device,
            )
        else:
            logits = logits[:, -1, :]

        if temperature < 1e-6:
            ret = logits.argmax(-1)
        else:
            ret = (logits / temperature).softmax().multinomial()
        return ret.flatten().realize()

    def __call__(
        self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0
    ) -> Tensor:
        forward = (
            self.forward_jit
            if (isinstance(tokens, Variable) or tokens.shape[1] == 1) and getenv("JIT")
            else self.forward
        )
        return forward(tokens, start_pos, temperature)


VOCAB_SIZE = 50257
MODEL_PARAMS = {
    "gpt2": dict(
        n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE
    ),  # 124M params
    "gpt2-medium": dict(
        n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE
    ),  # 350M params
    "gpt2-large": dict(
        n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE
    ),  # 774M params
    "gpt2-xl": dict(
        n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE
    ),  # 1558M params
}


class GPT2:
    @staticmethod
    def build(model_size="gpt2"):
        tokenizer = tiktoken.get_encoding("gpt2")

        model = Transformer(**MODEL_PARAMS[model_size])
        weights = torch_load(
            fetch(f"https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin")
        )
        # special treatment for the Conv1D weights we need to transpose
        transposed = (
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.c_fc.weight",
            "mlp.c_proj.weight",
        )
        for k in weights:
            if k.endswith(transposed):
                weights[k] = weights[k].T
        # lm head and wte are tied
        weights["lm_head.weight"] = weights["wte.weight"]

        load_state_dict(model, weights)

        if HALF:
            for l in get_state_dict(model).values():
                l.assign(l.half().realize())

        return GPT2(model, tokenizer)

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompt: str,
        max_length: int,
        temperature: float,
        timing: bool = False,
        batch_size: int = 1,
    ):
        prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
        toks = [prompt_tokens[:] for _ in range(batch_size)]
        start_pos = 0
        for _ in trange(max_length, disable=timing):
            GlobalCounters.reset()
            if timing:
                print("")
            st = GlobalCounters.time_sum_s
            with Timing(
                "ran model in ",
                on_exit=(
                    (
                        lambda et: (
                            f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
                            if DEBUG >= 2
                            else ""
                        )
                        + f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
                        + (
                            f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s"
                            if DEBUG >= 2
                            else ""
                        )
                    )
                    if DEBUG
                    else None
                ),
                enabled=timing,
            ):
                if batch_size == 1 and len(toks[0][start_pos:]) == 1:
                    tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
                else:
                    tokens = Tensor([x[start_pos:] for x in toks])
                tok = (
                    self.model(
                        tokens,
                        Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(
                            start_pos
                        ),
                        temperature,
                    )
                    .numpy()
                    .tolist()
                )
            start_pos = len(toks[0])
            for i, t in enumerate(tok):
                toks[i].append(t)
        return [self.tokenizer.decode(x) for x in toks]


# **** main code ****

if __name__ == "__main__":
    Tensor.no_grad = True
    print(f"using {Device.DEFAULT} backend")
    default_prompt = "What is the answer to life, the universe, and everything?"

    parser = argparse.ArgumentParser(
        description="Run GPT2 in tinygrad",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--prompt", type=str, default=default_prompt, help="Phrase to start with"
    )
    parser.add_argument(
        "--count", type=int, default=100, help="Max number of tokens to generate"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.8, help="Temperature in the softmax"
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="gpt2-medium",
        help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]",
    )
    parser.add_argument("--timing", action="store_true", help="Print timing per token")
    parser.add_argument("--seed", type=int, help="Set the random seed")
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Set the input batch size"
    )
    parser.add_argument(
        "--benchmark",
        type=int,
        default=-1,
        help="Benchmark GPT with the given number of tokens",
    )
    parser.add_argument("--noshow", action="store_true", help="Don't show the output")
    args = parser.parse_args()

    if args.seed is not None:
        Tensor.manual_seed(args.seed)
        np.random.seed(args.seed)

    print(f"using {args.model_size}")
    gpt2 = GPT2.build(args.model_size)

    if args.benchmark != -1:
        gpt2.model(
            Tensor.rand(args.batch_size, args.benchmark),
            Variable("a", 0, MAX_CONTEXT).bind(0),
        ).realize()
    else:
        texts = gpt2.generate(
            args.prompt,
            args.count,
            args.temperature,
            timing=args.timing,
            batch_size=args.batch_size,
        )
        if not args.noshow:
            print("Generated text:")
            if len(texts) == 1:
                print(texts[0])
            else:
                for i, text in enumerate(texts):
                    print(colored(f"Response {i}:", "green"), text)

        # validate output!
        if (
            args.temperature == 0
            and args.model_size == "gpt2-medium"
            and args.count == 10
        ):
            expected = {
                default_prompt: "What is the answer to life, the universe, and everything?\n\nThe answer is that we are all one",
                "Hello.": "Hello. I'm a little late to the party, but",
            }
            try:
                assert texts[0] == expected[args.prompt]
                print(colored("output validated", "green"))
            except KeyError:
                pass

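the output check at the very end of the script only runs with `--temperature 0`, because that is the only setting where the generated text is identical on every run. the actual sampling happens in the `Transformer` forward pass earlier in the post. here is a minimal numpy sketch of the usual temperature-sampling recipe, assuming it works that way: divide the logits by the temperature, apply a softmax, draw a token, and fall back to argmax when the temperature is close to zero. `sample_token` and the `1e-6` cutoff are hypothetical, used only for illustration, and are not part of the script.

```python
import numpy as np


def sample_token(logits: np.ndarray, temperature: float, rng: np.random.Generator) -> int:
    """Pick the next token id from a 1-D array of logits (hypothetical helper)."""
    if temperature < 1e-6:  # assumed cutoff for "temperature is zero"
        # greedy decoding: always take the most likely token, so output is deterministic
        return int(np.argmax(logits))
    # dividing by T < 1 sharpens the distribution, T > 1 flattens it
    scaled = logits / temperature
    scaled = scaled - scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()  # softmax
    return int(rng.choice(len(probs), p=probs))  # draw one token id


rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, -1.0])
print(sample_token(logits, 0.0, rng))  # 0, the argmax
print([sample_token(logits, 0.8, rng) for _ in range(10)])  # random; token 0 has ~68% probability at T=0.8
```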
the full script can also be found on my gitlab
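the `# 124M params` style comments next to `MODEL_PARAMS` can be checked by hand. below is a rough sketch, assuming the standard GPT-2 checkpoint layout that `GPT2.build` loads: a token embedding (`wte`), a position embedding (`wpe`) with GPT-2's 1024 positions, and per block a fused q/k/v projection (`c_attn`), an output projection (`c_proj`), a 4x-wide MLP and two LayerNorms, all with biases, plus one final LayerNorm. `lm_head` is tied to `wte`, so it is counted only once. `count_params` and `CONFIGS` are hypothetical names used for this check only.

```python
VOCAB_SIZE = 50257
N_POSITIONS = 1024  # GPT-2's position embedding table has 1024 rows

# same n_layers / dim settings as MODEL_PARAMS in the script
CONFIGS = {
    "gpt2": (12, 768),
    "gpt2-medium": (24, 1024),
    "gpt2-large": (36, 1280),
    "gpt2-xl": (48, 1600),
}


def count_params(n_layers: int, dim: int) -> int:
    """Approximate GPT-2 parameter count (hypothetical helper)."""
    embeddings = VOCAB_SIZE * dim + N_POSITIONS * dim  # wte + wpe (lm_head is tied to wte)
    attn = (dim * 3 * dim + 3 * dim) + (dim * dim + dim)  # c_attn (q, k, v fused) + c_proj
    mlp = (dim * 4 * dim + 4 * dim) + (4 * dim * dim + dim)  # c_fc + c_proj
    norms = 2 * (2 * dim)  # ln_1 and ln_2, each with a weight and a bias
    final_norm = 2 * dim  # ln_f
    return embeddings + n_layers * (attn + mlp + norms) + final_norm


for name, (n_layers, dim) in CONFIGS.items():
    print(f"{name}: {count_params(n_layers, dim) / 1e6:.1f}M")
# gpt2: 124.4M, gpt2-medium: 354.8M, gpt2-large: 774.0M, gpt2-xl: 1557.6M
```

this lines up with the comments in the script; gpt2-medium just comes out a bit closer to 355M than to 350M.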

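the "special treatment" in `GPT2.build` is needed because the Hugging Face GPT-2 checkpoint stores those four projections as `Conv1D` layers. their weight has shape `(in_features, out_features)` and is applied as `x @ W + b`, while tinygrad's `Linear` stores its weight as `(out_features, in_features)` and applies `x @ W.T + b`. this small numpy sketch (toy sizes, random values, not the real weights) shows that transposing the stored matrix gives exactly the same output:

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 4, 3
x = rng.standard_normal((2, in_features))  # a batch of 2 input vectors
bias = rng.standard_normal(out_features)

# Conv1D layout in the checkpoint: (in_features, out_features), applied as x @ W + b
w_conv1d = rng.standard_normal((in_features, out_features))
y_conv1d = x @ w_conv1d + bias

# Linear layout: (out_features, in_features), applied as x @ W.T + b
w_linear = w_conv1d.T  # the same idea as `weights[k] = weights[k].T` in build()
y_linear = x @ w_linear.T + bias

print(w_conv1d.shape, w_linear.shape)  # (4, 3) (3, 4)
print(np.allclose(y_conv1d, y_linear))  # True
```

note that `attn.c_proj` is square (`dim x dim`), so for that matrix the shapes would match even without the transpose; forgetting it would silently produce wrong outputs instead of a shape error.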
Thank you for reading this lengthy post, and i hope you found it interesting!

Ciao for now,

Giovannni