Simulating Next Token Prediction: Transformers vs. Reinforcement Learning
Introduction
As an ML engineer, I’ve always been fascinated by the different approaches to solving complex problems in machine learning. Recently, I decided to dive into an interesting experiment: comparing Transformers and Reinforcement Learning (RL) on a simple next token prediction task. In this blog post, I’ll share my journey of simulating next token prediction with both methods, focusing on the task of adding two-digit numbers.
The Challenge: Adding Two-Digit Numbers
I chose to work with two-digit addition for a few reasons:
- It’s simple enough to implement quickly
- It’s complex enough to showcase the strengths of both approaches
- It provides a clear benchmark for comparing performance
Dataset Curation
First things first, I needed a dataset. I created a simple Python script to generate pairs of two-digit numbers and their sums. Here’s a snippet of how I did it:
# dataset idea from https://github.com/karpathy/minGPT/blob/master/projects/adder/adder.py
from random import shuffle

import numpy as np

def make_dataset():
    ds = []
    # Every pair of two-digit numbers (00-99) plus their zero-padded 3-digit sum
    for i in range(100):
        for j in range(100):
            s = i + j
            ds.append([i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10])
    shuffle(ds)
    ds = np.array(ds).astype(np.float32)
    ds_X = ds[:, 0:6]          # inputs: the first 6 digits
    ds_Y = np.copy(ds[:, 1:])  # targets: the same sequence shifted left by one
    train_size = int(len(ds_X) * 0.85)  # 85/15 train/test split
    ds_X_train, ds_X_test = ds_X[0:train_size], ds_X[train_size:]
    ds_Y_train, ds_Y_test = ds_Y[0:train_size], ds_Y[train_size:]
    return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
# The dataset looks like this:
# X_train = [array([6., 7., 4., 6., 1., 1.], dtype=float32), ...]
# Y_train = [array([7., 4., 6., 1., 1., 3.], dtype=float32), ...]
# The first four digits of X_train encode the operands 67 and 46, whose sum is 113.
# Y_train is X_train shifted one position to the left with the final digit of the
# sum (3) appended, so that last digit is the new token the model has to predict.
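Each raw row holds 7 digits: two for each operand and three for their zero-padded sum. The input X is the first 6 digits (both operands plus the hundreds and tens digits of the sum), and the target Y is the same sequence shifted one position to the left, so its final element is the last digit of the sum. The 10,000 rows are split 85/15 into 8,500 training and 1,500 test examples.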
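As a quick check of the layout described above, the sketch below encodes one pair the same way make_dataset() builds each row and then splits it into input and target. The encode_pair helper is hypothetical, written only for this illustration.

```python
import numpy as np


def encode_pair(a: int, b: int) -> list:
    """Hypothetical helper: encode a + b as the 7-digit row built in make_dataset()."""
    s = a + b
    return [a // 10, a % 10, b // 10, b % 10, s // 100, (s // 10) % 10, s % 10]


row = np.array(encode_pair(67, 46), dtype=np.float32)
x, y = row[:6], row[1:]  # input and left-shifted target, as in make_dataset()

print(x.astype(int).tolist())  # [6, 7, 4, 6, 1, 1]
print(y.astype(int).tolist())  # [7, 4, 6, 1, 1, 3]
# The sum sits in the last three digits of y; its final digit is the token to predict.
print(int("".join(str(d) for d in y[-3:].astype(int))))  # 113
```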
Approach 1: Transformers
For the Transformer approach, I used a simple implementation based on tinygrad. The model configuration was as follows:
- Vocabulary Size: 10 (digits 0 to 9)
- Maximum Sequence Length: 6
- Number of Layers: 2
- Embedding Dimensions: 128
- Number of Attention Heads: 4
- Feedforward Network Dimensions: 32
The training script is adapted from tinygrad's example: https://github.com/tinygrad/tinygrad/blob/master/examples/transformer.py
#!/usr/bin/env python3
# Run from a tinygrad checkout: the `extra` package lives in the tinygrad repository, not the pip package.
import numpy as np
import random
from tinygrad.nn.state import get_parameters
from tinygrad.nn.optim import Adam
from extra.training import train, evaluate
from extra.models.transformer import Transformer

# make_dataset() is the function from the Dataset Curation section above.

if __name__ == "__main__":
    # vocab size 10, max sequence length 6, 2 layers, embedding dim 128, 4 heads, FFN dim 32
    model = Transformer(10, 6, 2, 128, 4, 32)
    X_train, Y_train, X_test, Y_test = make_dataset()
    lr = 0.003
    for i in range(10):
        # A fresh Adam optimizer each round, with the learning rate decayed by 1.2x
        optim = Adam(get_parameters(model), lr=lr)
        train(model, X_train, Y_train, optim, 50, BS=64, allow_jit=True)
        acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
        lr /= 1.2
        print(f'reducing lr to {lr:.4f}')
        if acc > 0.998:
            wrong = 0
            for k in range(len(Y_test_preds)):
                # A row counts as wrong if any of its 6 predicted digits differs from the target
                if (Y_test_preds[k] != Y_test[k]).any():
                    wrong += 1
                    a, b, c, x = (X_test[k, :2].astype(np.int32), X_test[k, 2:4].astype(np.int32),
                                  Y_test[k, -3:].astype(np.int32), Y_test_preds[k, -3:].astype(np.int32))
                    print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
            print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
Training the Transformer was surprisingly quick: it took less than 15 seconds on my machine. The results were impressive; these are the only test rows it got wrong in one run:
58 + 98 = 153 (correct: 156)
18 + 58 = 073 (correct: 076)
68 + 18 = 083 (correct: 086)
99 + 99 = 198 (correct: 198)
93 + 13 = 107 (correct: 106)
98 + 28 = 123 (correct: 126)
13 + 13 = 021 (correct: 026)
78 + 38 = 113 (correct: 116)
03 + 13 = 012 (correct: 016)
Wrong predictions: 9, acc = 0.9992 (Result may vary)
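As you can see, even when the model made mistakes, it was very close to the correct answer: every incorrect sum was wrong only in its last digit, often by exactly 3 (for example, 156 predicted as 153). The 99 + 99 row appears even though its sum is correct because a row counts as wrong when any of its six predicted digits differs from the target, not just the three digits of the sum.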
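The loop in the script counts a row as wrong if any of its six predicted digits differs from the target. The sketch below contrasts that row-level count with an accuracy averaged over individual digits, using hypothetical toy arrays. I believe tinygrad's evaluate() reports the per-digit average, but that is my reading of extra/training.py, so check the version you use.

```python
import numpy as np


def token_and_row_errors(y_true, y_pred):
    """Return (accuracy averaged over all positions, number of rows with any wrong digit)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    token_acc = float((y_true == y_pred).mean())
    wrong_rows = int((y_true != y_pred).any(axis=1).sum())  # same test as the loop above
    return token_acc, wrong_rows


# Hypothetical toy data: targets for 67 + 46 and 58 + 98, with the second sum mispredicted.
y_true = np.array([[7, 4, 6, 1, 1, 3], [8, 9, 8, 1, 5, 6]])
y_pred = np.array([[7, 4, 6, 1, 1, 3], [8, 9, 8, 1, 5, 3]])

token_acc, wrong_rows = token_and_row_errors(y_true, y_pred)
print(f"token accuracy = {token_acc:.4f}, wrong rows = {wrong_rows}")
# token accuracy = 0.9167, wrong rows = 1
```

A single wrong digit costs a whole row in the count but only one sixth of a row in the averaged accuracy.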
Approach 2: Reinforcement Learning
For the RL approach, I decided to use Proximal Policy Optimization (PPO) from the stable_baselines3 library, and I created a custom Gym environment to simulate the addition task.
Introduction to Reinforcement Learning
Reinforcement Learning (RL) involves training agents to make a series of decisions by rewarding desired actions and penalizing undesired ones. In this case, the RL agent observes the same 6-digit input as the Transformer and learns to predict the last digit of the sum, receiving a reward based on whether its prediction is correct.
Creating a Gym Environment
The environment shows the agent one 6-digit input row at a time and expects it to predict the last digit of the sum. Here’s a simplified version of the environment:
# Gymnasium API (the one Stable-Baselines3 2.x expects)
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete

class AdderTokenPredictorEnv(gym.Env):
    # One step per dataset row: observe 6 digits, act by predicting the last digit of the sum
    metadata = {}

    def __init__(self, x, y) -> None:
        super().__init__()
        self.x = x
        self.y = y
        self.action_space = Discrete(start=0, n=10)  # digits 0-9
        self.observation_space = Box(0, 10, shape=(len(x[0]),), dtype=np.float32)
        self.len_of_dataset = len(x)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.wrong = 0
        self.counter = 0
        self.state = self.x[self.counter]
        return self.state, {}

    def step(self, action):
        terminated = False
        done = False
        info = {}
        y = self.y[self.counter]
        correct_action = int(y[-1])  # target: the final digit of the sum
        if action == correct_action:
            reward = 1
        else:
            reward = -1
            self.wrong += 1
        self.counter += 1
        # The episode ends after one full pass over the dataset
        if self.counter == self.len_of_dataset:
            done = True
        info["done"] = done
        info["terminated"] = terminated
        info["counter"] = self.counter
        info["remaining"] = self.len_of_dataset - self.counter
        info["wrong"] = self.wrong
        info["correct"] = correct_action
        info["predicted"] = action
        info["x"] = self.state
        info["y"] = y
        # Gymnasium returns (obs, reward, terminated, truncated, info): `done` fills the
        # terminated slot and `terminated` (always False here) fills the truncated slot
        if done or terminated:
            return self.state, reward, done, terminated, info
        self.state = self.x[self.counter]
        return self.state, reward, done, terminated, info

    def close(self):
        pass
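from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

# Hypothetical helper (not shown in the post): DummyVecEnv expects callables that build envs
def make_env(x, y):
    return lambda: AdderTokenPredictorEnv(x, y)

ds_X_train, ds_Y_train, ds_X_test, ds_Y_test = make_dataset()
vec_env = DummyVecEnv([make_env(ds_X_train, ds_Y_train)] * 64)  # 64 parallel training envs
eval_vec_env = DummyVecEnv([make_env(ds_X_test, ds_Y_test)])
ppo = PPO(
    "MlpPolicy",
    vec_env,
    device="cuda",
    learning_rate=0.0001,
    n_steps=2048,
    batch_size=64,
    tensorboard_log="logs",
)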
When setting up the RL environment, I decided to keep things as similar to the Transformer setup as possible. Just like the Transformer, the RL agent only had to predict digits from 0 to 9 and worked with the same 6-digit input format, although it predicts just the final digit of the sum rather than every shifted position. To start, I went with the simplest reward function I could think of: +1 for correct guesses, -1 for incorrect ones. Sometimes in machine learning, starting simple is the best way to understand the problem.
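To put the PPO settings above in perspective, the arithmetic below works out how much data each policy update sees. It assumes Stable-Baselines3's default n_epochs=10 (the post does not set it) and the 8,500-row training split from make_dataset(); the 4.5 million timesteps are the budget of the first run described next.

```python
n_envs = 64                  # DummyVecEnv copies of the training environment
n_steps = 2048               # PPO n_steps, collected per environment
batch_size = 64              # PPO minibatch size
n_epochs = 10                # assumption: Stable-Baselines3's default for PPO
train_rows = 8500            # one episode = one pass over the training split
total_timesteps = 4_500_000  # budget of the first run (about 35 minutes)

rollout_size = n_envs * n_steps  # transitions collected before each PPO update
minibatches_per_update = n_epochs * rollout_size // batch_size
full_rollouts = total_timesteps // rollout_size
passes_per_env = total_timesteps / n_envs / train_rows

print(f"{rollout_size=} {minibatches_per_update=} {full_rollouts=} {passes_per_env=:.1f}")
```

So the whole run amounts to only a few dozen large policy updates, and each of the 64 environments walks through the training set roughly eight times.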
After letting the RL agent train for about 35 minutes (which translated to 4.5 million timesteps, quite a workout for my PC!), I was eager to see the results. The outcome? 12 incorrect predictions out of 1500 test examples. Not bad, but not great either. What really caught my eye was a peculiar pattern: the model seemed to have a hard time telling 7 and 8 apart. Here are the misclassified test examples:
Test_X Test_Y Correct Predicted
[6. 9. 0. 9. 0. 7.] [9. 0. 9. 0. 7. 8.] [8] [7]
[1. 9. 3. 9. 0. 5.] [9. 3. 9. 0. 5. 8.] [8] [7]
[5. 9. 8. 9. 1. 4.] [9. 8. 9. 1. 4. 8.] [8] [7]
[4. 9. 5. 9. 1. 0.] [9. 5. 9. 1. 0. 8.] [8] [7]
[7. 9. 4. 9. 1. 2.] [9. 4. 9. 1. 2. 8.] [8] [7]
[9. 9. 4. 9. 1. 4.] [9. 4. 9. 1. 4. 8.] [8] [7]
[9. 9. 8. 9. 1. 8.] [9. 8. 9. 1. 8. 8.] [8] [7]
[0. 9. 6. 9. 0. 7.] [9. 6. 9. 0. 7. 8.] [8] [7]
[3. 9. 8. 9. 1. 2.] [9. 8. 9. 1. 2. 8.] [8] [7]
[2. 9. 5. 9. 0. 8.] [9. 5. 9. 0. 8. 8.] [8] [7]
[9. 9. 0. 9. 1. 0.] [9. 0. 9. 1. 0. 8.] [8] [7]
[9. 9. 5. 9. 1. 5.] [9. 5. 9. 1. 5. 8.] [8] [7]
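This confusion between 7 and 8 persisted even after I let it train for longer. Every failing row adds two numbers that both end in 9, so the correct last digit is 8 (9 + 9 = 18), and the agent kept answering 7. It was like watching someone squint at a doctor’s handwriting: close, but not quite there. I realized it was time to try something new.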
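The last digit of a sum depends only on the two ones digits (the second and fourth entries of Test_X). The sketch below re-checks the 12 rows from the table against that rule; it only re-reads the table and does not explain why the agent settled on 7.

```python
import numpy as np

# Test_X rows from the table: [a_tens, a_ones, b_tens, b_ones, sum_hundreds, sum_tens]
failing_x = np.array([
    [6, 9, 0, 9, 0, 7], [1, 9, 3, 9, 0, 5], [5, 9, 8, 9, 1, 4],
    [4, 9, 5, 9, 1, 0], [7, 9, 4, 9, 1, 2], [9, 9, 4, 9, 1, 4],
    [9, 9, 8, 9, 1, 8], [0, 9, 6, 9, 0, 7], [3, 9, 8, 9, 1, 2],
    [2, 9, 5, 9, 0, 8], [9, 9, 0, 9, 1, 0], [9, 9, 5, 9, 1, 5],
])


def last_sum_digit(x):
    """Ones digit of a + b: only the two ones digits matter (any carry goes to the tens)."""
    return (x[:, 1] + x[:, 3]) % 10


a = 10 * failing_x[:, 0] + failing_x[:, 1]
b = 10 * failing_x[:, 2] + failing_x[:, 3]

print(sorted(set(last_sum_digit(failing_x).tolist())))  # [8]
print(bool(((failing_x[:, 1] == 9) & (failing_x[:, 3] == 9)).all()))  # True
# Sanity check: the sum digits already present in the input agree with a + b
print(bool(((a + b) // 10 == 10 * failing_x[:, 4] + failing_x[:, 5]).all()))  # True
```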
My first thought was to normalize the input data. In many machine learning tasks, normalization can help the model learn more effectively by putting all the inputs on a similar scale. So, I implemented normalization for all states in the training process, curious to see if this would help our agent distinguish its 7s from its 8s.
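from stable_baselines3.common.vec_env import VecNormalize
# By default VecNormalize normalizes both observations and rewards with running statistics
vec_env = VecNormalize(vec_env)
# Note: this wrapper keeps its own statistics; SB3's docs recommend reusing the training
# statistics, with training=False and norm_reward=False, when evaluating
eval_vec_env = VecNormalize(eval_vec_env)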
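VecNormalize keeps a running mean and variance for each observation feature and standardizes observations with them, clipping the result. The class below is a simplified NumPy sketch of that idea, not Stable-Baselines3's code: the RunningNormalizer name is hypothetical, and its epsilon and clip defaults are chosen to mirror VecNormalize's epsilon=1e-8 and clip_obs=10.0.

```python
import numpy as np


class RunningNormalizer:
    """Hypothetical, simplified stand-in for observation normalization (not SB3's code)."""

    def __init__(self, shape, epsilon=1e-8, clip=10.0):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = 0
        self.epsilon = epsilon  # avoids division by zero for constant features
        self.clip = clip        # clamps extreme normalized values

    def update(self, batch):
        """Merge a batch of observations into the running mean and variance."""
        batch = np.asarray(batch, dtype=np.float64)
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        # Parallel-variance combination of the old and new statistics
        m2 = self.var * self.count + batch_var * batch_count + delta**2 * self.count * batch_count / total
        self.mean = self.mean + delta * batch_count / total
        self.var = m2 / total
        self.count = total

    def normalize(self, obs):
        z = (np.asarray(obs, dtype=np.float64) - self.mean) / np.sqrt(self.var + self.epsilon)
        return np.clip(z, -self.clip, self.clip)


# Usage: random digit observations shaped like the environment's 6-value states
rng = np.random.default_rng(0)
obs = rng.integers(0, 10, size=(256, 6))
normalizer = RunningNormalizer(shape=(6,))
normalizer.update(obs[:128])
normalizer.update(obs[128:])
print(normalizer.normalize(obs[:2]).round(2))
```

Because VecNormalize also rescales rewards by default, wrapping the environment changes the reward signal PPO sees, not only its inputs.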
The results were… interesting. The good news was that the model started learning faster — always a win in my book. But here’s the catch: we now had 26 incorrect predictions out of 1500. It seemed like we had taken one step forward and two steps back. At this point, I could almost hear the model asking for more training time. But I had another trick up my sleeve: a learning rate scheduler.
def linear_schedule(initial_value: float):
    # SB3 calls the returned function with progress_remaining, which goes from 1.0 to 0.0
    def func(progress_remaining: float) -> float:
        return progress_remaining * initial_value
    return func

ppo = PPO(
    "MlpPolicy",
    vec_env,
    device="cuda",
    learning_rate=linear_schedule(0.0001),
    n_steps=2048,
    batch_size=64,
    tensorboard_log="logs",
)
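Stable-Baselines3 calls a learning-rate schedule with progress_remaining, which starts at 1.0 and falls linearly to 0.0 as training approaches total_timesteps. The sketch below evaluates the linear_schedule from above (repeated so the block runs on its own) at a few points of a 4.5-million-step run; the progress_remaining helper is a hypothetical, simplified stand-in for SB3's internal bookkeeping.

```python
def linear_schedule(initial_value: float):
    def func(progress_remaining: float) -> float:
        return progress_remaining * initial_value
    return func


def progress_remaining(num_timesteps: int, total_timesteps: int) -> float:
    """Hypothetical helper: 1.0 at the start of training, 0.0 at the end."""
    return 1.0 - num_timesteps / total_timesteps


schedule = linear_schedule(0.0001)
total = 4_500_000
for step in (0, 1_125_000, 2_250_000, 4_500_000):
    lr = schedule(progress_remaining(step, total))
    print(f"timestep {step:>9,}: lr = {lr:.2e}")
```

With this schedule the learning rate halves by the midpoint of training and reaches zero at the very end, so the last updates barely change the policy.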
Adding a linear learning rate scheduler felt like giving the model a personal trainer, one who gradually eases off the intensity as training goes on. But even with this new addition, I wasn’t seeing the improvements I had hoped for. It was clear that we needed to dig deeper.
That’s when it hit me: maybe our reward function was too simplistic. I decided to make it graded, basing the penalty for a wrong guess on the distance between the agent’s prediction and the correct digit. This way, the agent would get more nuanced feedback about how close its guesses were.
# current
if action == correct_action:
    reward = 1
else:
    reward = -1

# after: the penalty grows with the distance between the predicted and correct digit
if action == correct_action:
    reward = 1
else:
    reward = -abs(correct_action - action)
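To see what the new reward changes, the sketch below tabulates both reward functions for every possible action when the correct digit is 8, and computes the average reward a uniformly random guesser would receive. The old_reward and new_reward names are mine; the formulas are the two snippets above.

```python
def old_reward(correct: int, action: int) -> int:
    return 1 if action == correct else -1


def new_reward(correct: int, action: int) -> int:
    # Wrong answers are penalized by how far the guess is from the correct digit
    return 1 if action == correct else -abs(correct - action)


correct = 8
for action in range(10):
    print(f"action={action}: old={old_reward(correct, action):+d} new={new_reward(correct, action):+d}")

# Average reward of a uniformly random guess over the 10 digits
print(sum(old_reward(correct, a) for a in range(10)) / 10)  # -0.8
print(sum(new_reward(correct, a) for a in range(10)) / 10)  # -3.6
```

One detail stands out: for the 7-versus-8 mistake seen earlier, both functions give -1, so the new reward does not penalize that particular error more heavily. What changes is that guesses far from the answer now cost much more.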
With this new reward function in place, I decided to run a full set of experiments: with and without normalization, and with and without the learning rate scheduler. It was like setting up a mini machine learning tournament.
And the winner is… drum roll, please… the RL PPO Agent with both normalization and the learning rate scheduler! This combination achieved 100% accuracy on our test dataset. No errors. Nada. Zip. Our little agent had finally learned to predict the last digit of the sum of two numbers perfectly!
Conclusion
This experiment yielded some interesting insights:
- Transformers are incredibly efficient at learning patterns, achieving high accuracy in a very short training time.
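- RL, while slower to train, can achieve perfect accuracy with the right tweaks, at least on this narrower task of predicting only the final digit of the sum.
- For RL, normalization sped up learning and was part of the best-performing configuration, although on its own it did not reduce the error count.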
- The choice of reward function significantly impacted the results.
- Learning rate schedules varied in effectiveness depending on the specific task setup.
While this experiment is far from replicating the capabilities of advanced models like BERT or GPT, it provides an interesting glimpse into how different approaches can be applied to similar problems.
As a machine learning engineer, I found this experiment enlightening. It reinforced the importance of choosing the right tool for the job and the value of experimentation. While Transformers showed impressive out-of-the-box performance, the process of refining the RL approach led to deeper insights about the problem space.
In the future, I’m curious to explore how these approaches might scale to more complex tasks. Could RL be effectively applied to more advanced language tasks? How would the performance gap between Transformers and RL change as the problem complexity increases?
This experiment has opened up a world of questions, and I’m excited to continue exploring the intersection of different machine learning paradigms in my future work.
GitHub link — https://github.com/xaviruvpadhiyar98/rl-examples/blob/main/adder_next_token_prediction.py