
Thinking Before You Speak, Finetuning Your Own Reasoning Models

Kaif
12 min read · Feb 20, 2025

I often find that supervised learning gives a model a solid training signal, yet it never feels good enough to make me say, “this is it,” or “this will probably solve that one error I’m stuck with,” or build something I fear will “take my job.” But over the past few months, something I never anticipated getting really obsessed over, and would call “more than good,” is Reinforcement Learning.

The idea of quite literally telling the model, “Hey, I like my outputs in Markdown, otherwise I hate you” (or, technically speaking, defining a reward function) to supervise and train it is quite fascinating to me. Over the last six months, we went from OpenAI’s o1-preview to SOTA reasoning models like DeepSeek R1 and o3-mini, which build on the well-researched ideas of Chain-of-Thought (CoT) prompting and RLHF-style training. Today, I’ll explain my understanding of how RLHF works and how you can train a model like DeepSeek R1 on your own!

What is Chain-of-Thought (CoT)? ⛓️

But first, let’s understand what CoT prompting is. Chain-of-thought prompting is a technique that guides an LLM to break down a problem into a series of intermediate, logical steps before arriving at a final answer. It’s very similar to how psychologists suggest you think out loud when struggling with a complex task. This helps a person break down a problem into smaller bits so that implementing logic and arriving at a conclusion becomes clearer.

Let’s take an example:
For a simple linear equation, solve for y → 5y - 36 = 14

In this example, a reasoning model would solve it like this:
Step 1: Add 36 to both sides → 5y = 50
Step 2: Divide both sides by 5 → 5y/5 = 50/5
Step 3: Final answer → y = 10

By breaking a problem down, a reasoning model can check its own work, leading to better and more accurate outcomes. If it makes a mistake, there’s a better chance the model will catch and fix it during the “thinking” process.

Now that we’ve cleared that up, let’s move on to what RLHF is.

From RLHF to Reasoning Models

The whole RLHF hype, in plain and simple words: it is a technique used to fine-tune language models so that their outputs better align with human preferences. The training process is still supervised, except here’s what really changes:

Generally, LLMs are provided with the input sequence and the desired output sequence. For a summarization dataset, the model is given the text to be summarized and the corresponding reference summary. What’s different in RLHF is that the model first generates multiple outputs for a given task. These outputs are then ranked by humans on multiple criteria like relevance, safety, etc. Next, we train a reward model that learns from the generated outputs and the ratings given to each one, so it knows what is expected. Finally, the main model (the LLM we intend to train) is fine-tuned with reinforcement learning, using the reward model’s scores as the reward signal, so that it adjusts its parameters to produce the outputs a human would prefer. The RL algorithm typically used for this last step is Proximal Policy Optimization (PPO).
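
To make the reward-model step concrete, here is a minimal sketch of the pairwise ranking loss commonly used to train a reward model (plain PyTorch; the tensor names are hypothetical and not tied to any particular library):

import torch
import torch.nn.functional as F

def reward_model_loss(chosen_scores: torch.Tensor, rejected_scores: torch.Tensor) -> torch.Tensor:
    # Each pair holds the reward model's scalar score for the human-preferred
    # output and for the less-preferred one; the loss pushes the preferred
    # score above the rejected one (a Bradley-Terry style ranking loss).
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()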

Now, this sounds like a lot of work because, first, we have to rank the outputs ourselves, and second, we need to train two models in the process, increasing computational costs. Plus, propagating the reward signal through an RL loop this way is quite unstable. Thus, upon further research, Direct Preference Optimization (DPO) was introduced as a fine-tuning method that trains aligned models in a more direct and simple way, keeping the reward mechanism stable.

Recall that with PPO, we fine-tune our language model by using human feedback to train a separate reward model and then use that reward signal in a reinforcement learning loop. This involves policy gradients, careful tuning of learning rates, and ensuring the updates remain “proximal” or close enough to the original model to avoid instability.
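
For reference, the “proximal” part comes from PPO’s clipped surrogate objective (the standard form from Schulman et al., 2017):

L_CLIP(θ) = E[ min( r_t(θ)·A_t, clip(r_t(θ), 1 − ε, 1 + ε)·A_t ) ], where r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t)

The ratio r_t measures how far the updated policy has drifted from the old one, and the clip term caps how much any single update can change it.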

DPO, on the other hand, takes care of the alignment process within the same training loop. So basically, we fine-tune a single model whose training loss directly encodes the human preferences, with no separate reward model in the loop.

In DPO, we directly compare pairs of outputs: for a given prompt, a human prefers response A over response B, and that preference is the training signal. The input data works this way: for each input sequence (the task/question), we provide both a “chosen response” and a “rejected response.”

Direct Preference Optimization (DPO) (Rafailov et al., 2023) — https://arxiv.org/abs/2305.18290

So, DPO optimizes the model so that the probability of the chosen response is higher than that of the rejected response. The objective raises the log-likelihood of each chosen response relative to the rejected one (measured against a frozen reference model), aligning the model with human preferences. This works with both setups: RLHF (Reinforcement Learning from Human Feedback) and RLAIF (Reinforcement Learning from AI Feedback).
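
In code, the core of the DPO objective is surprisingly compact. Here is a minimal sketch, assuming you have already computed the summed log-probabilities of the chosen and rejected responses under both the policy and the frozen reference model (all names below are hypothetical):

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # How much more (or less) likely each response is under the policy
    # compared to the frozen reference model
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    # Push the chosen response's log-ratio above the rejected one's;
    # beta controls how strongly the policy may drift from the reference
    logits = beta * (chosen_logratio - rejected_logratio)
    return -F.logsigmoid(logits).mean()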

Leveraging RLHF concepts to obtain a “Reasoning Model”

Now, talking about the “star model,” the “SOTA model,” or the “so-called AGI”: DeepSeek R1 is trained using a relatively new method, Group Relative Policy Optimization (GRPO), introduced by the authors of DeepSeekMath. That paper shows how reasoning helps solve problems where step-by-step logic is essential, in areas like coding, math, etc.

I know this is quite literally the third method I have explained so far, but I promise this is the last one we’ll be talking about.

While DPO reformulates alignment as a supervised ranking task, directly encouraging the model to assign higher probabilities to outputs that human evaluators prefer, GRPO takes a more holistic approach by evaluating a whole group of responses together for the same prompt.

Let’s take another example, this time a summarization task:
Prompt: “Summarize the following article about climate change: [Article Text]”

Now, suppose our model generates three candidate summaries:
- Response A: “Climate change refers to long-term alterations in temperature and weather patterns, mainly due to human activities.”
- Response B: “Climate change is the gradual shift in global weather patterns and temperatures caused largely by greenhouse gas emissions from human activities.”
- Response C: “The article discusses how human actions, especially the burning of fossil fuels, are altering the Earth’s climate, leading to warmer temperatures and more extreme weather events.”

A reward function, designed to assess criteria like accuracy, conciseness, and informativeness, evaluates these outputs and assigns:
- Reward for A: 0.8
- Reward for B: 0.7
- Reward for C: 0.9

In GRPO, rather than evaluating a single output in isolation, we consider the entire group of responses for this prompt. If we go into nerd mode for a second and look at the internal math, here’s how this works:

  • Compute group statistics: the mean μ and standard deviation σ of the rewards in the group.
  • Calculate group-relative advantages: for each response i, the advantage is A_i = (r_i - μ) / σ.

GRPO then uses a PPO-like objective function (don’t worry, everything still happens in the same training loop, just like DPO), but instead of using a separate value network to compute the baseline, it uses these group-relative advantages. In our example, Response C’s high advantage (+1.22) tells the model that, relative to the other candidates, it’s the best summary. The GRPO objective then adjusts the model’s parameters so that the probability of generating responses like Response C increases, while that of responses like Response B decreases due to its negative advantage.

Of course, the actual math behind computing the advantages, the loss, and the KL divergence is more involved, but the quick sketch below shows where the numbers in our example come from.
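
A quick plain-Python check of the group-relative advantages for the three example rewards:

import statistics

rewards = [0.8, 0.7, 0.9]           # rewards for responses A, B, C
mu = statistics.mean(rewards)        # group mean = 0.8
sigma = statistics.pstdev(rewards)   # population standard deviation ≈ 0.082

advantages = [(r - mu) / sigma for r in rewards]
print([round(a, 2) for a in advantages])  # ≈ [0.0, -1.22, 1.22]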

That’s all great but how do I train my own reasoning model? 🤔💭

Well, let’s talk about the part where the magic happens. We’ll be using the TRL library by Hugging Face, which natively supports GRPO finetuning. Since most of you (including me) don’t really want to pre-train a model from scratch, finetuning is the way to go. The idea is to finetune an instruct model (an instruction-finetuned LLM like Qwen-Instruct or Llama-Instruct) to behave like a reasoning model on a downstream task or a generalized dataset. This way the model “thinks before speaking”: instead of getting too excited and blurting out an answer, it really thinks about its own generated response, checks itself, and tries to be absolutely sure before answering the question.

Now, most of us don’t have Nvidia H100 GPUs at home, so full GRPO finetuning can be quite compute-intensive and cause headaches over OOM (Out of Memory) errors. In that case, don’t worry, I’ve got you! You can understand the code from this article and later implement it using Unsloth, a memory-efficiency framework that makes your LLMs train faster than F1 cars 🏎️ with wayyy less memory, or, as I like to call it, magic ✨. I’ll link the notebooks below in the references, which you can run directly to see the action for yourselves.

The code I’ll be demonstrating below uses only TRL (plus an optional LoRA config) to finetune an Instruct model.

Note: if you are full finetuning, a model of around 1.5B parameters or larger should do the job; if you are using Unsloth with QLoRA, Unsloth likewise recommends a model of at least 1.5B parameters to correctly generate those “reasoning tokens,” which smaller models may struggle with.

Here’s what we’ll be working with:

  1. Dataset: GSM8K Dataset by OpenAI
  2. Base Model: Qwen2.5-3B-Instruct
  3. Finetuning Mechanism: GRPOTrainer + LoRA Config

Env Setup & Importing libraries

Make sure you have the following packages installed and then import them as follows:

import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

Next, we prep the dataset and set up the reward functions. The reward functions are the essence of getting our very own reasoning model.

Let’s start by defining two strings that establish the output format of the model’s responses:

  1. SYSTEM_PROMPT: Instructs the model to generate responses that include a <reasoning> block and an <answer> block, ensuring all thinking tokens and the actual answer stay separated.
  2. XML_COT_FORMAT: Provides a template with placeholders ({reasoning} and {answer}) that will be replaced with actual reasoning tokens and answers.

This formatting requirement is why we want the base model to be an instruction-finetuned model wherever possible: it is far more likely to follow the format consistently.

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

Next, we quickly preprocess the dataset to follow the structure we defined earlier and keep everything consistent.

def extract_xml_answer(text: str) -> str:
    # Pull out whatever sits between the <answer> tags
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    # GSM8K solutions end with the final numeric answer after a "####" marker
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data

dataset = get_gsm8k_questions()
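
We call get_gsm8k_questions() right away so that the dataset variable is ready when we build the trainer later. For context on the #### parsing: every GSM8K solution ends with its final numeric answer after a #### marker, which extract_hash_answer strips out. A quick, hypothetical example:

print(extract_hash_answer("She spends 3 * 2 = 6 dollars in total. #### 6"))  # -> "6"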

Now we’ll define the ✨ reward functions ✨. Reward functions let us steer the model toward responses with the desired structure. We’ll pass reward functions for the structure, formatting, and correctness of the output.

So every time the model’s output satisfies the criteria encoded in the reward functions, the Trainer rewards the model, which simply encourages it to keep pushing in the same direction.

We’ll go through each reward function mentioned below:

  1. The correctness_reward_func checks the correctness of each answer by extracting the answer produced by the model and comparing it against the expected answer, verifying whether they match exactly. If yes, great, here are 2 points for the model; otherwise 0.
  2. The int_reward_func again extracts the answer produced by the model and checks whether it consists solely of digits, since this is a math dataset and we expect the output to be a number. The reward here is 0.5 if the answer is a digit, otherwise 0.
  3. Next, the strict_format_reward_func and soft_format_reward_func apply regex patterns to the completion. The strict pattern enforces the expected format exactly, with the <reasoning> and <answer> tags on separate lines; since this is a difficult bar to clear, we additionally set a more lenient pattern that just checks for the presence of <reasoning>…</reasoning> and <answer>…</answer> anywhere in the response. Each of the two functions awards 0.5 if its pattern matches, otherwise 0.
  4. Lastly, the xmlcount_reward_func checks the formatting of the model’s output by verifying that it contains the required XML tags in the correct order and structure. It looks for:
  • Exactly one <reasoning>\n and one \n</reasoning>\n, awarding 0.125 points each.
  • Exactly one \n<answer>\n and one \n</answer>, again adding 0.125 points each.
  • Additionally, it subtracts a small penalty of 0.001 per extra character of text trailing after the closing </answer> tag, ensuring that the model’s response stays precise.

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Log one sample per batch so we can eyeball progress during training
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
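
As a quick sanity check, a perfectly formatted completion (a hypothetical one below) earns the full 0.5 from xmlcount_reward_func, with no trailing-text penalty:

good = "<reasoning>\nAdd 36 to both sides, then divide by 5.\n</reasoning>\n<answer>\n10\n</answer>\n"
print(count_xml(good))  # 0.5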

Now that we have all that out of the way, we’ll initialize GRPO and LoRA configs.

model_name = "Qwen/Qwen2.5-3B-Instruct"
output_dir = "outputs/MathAwesome-Qwen-3B"
run_name = "Qwen-3B-GRPO-gsm8k"
training_args = GRPOConfig(
output_dir = output_dir,
run_name = run_name,
learning_rate = 5e-6,
weight_decay = 0.1,
lr_scheduler_type = 'cosine',
logging_steps = 1,
bf16 = True, # Disable if you don't have Ampere GPU or higher
per_device_train_batch_size = 1,
gradient_accumulation_steps = 4,
num_generations = 16, # The GRPO Magic ✨
use_vllm=True, # For faster generations, do !pip install vllm
max_prompt_length = 256,
max_completion_length = 786,
num_train_epochs = 1,
save_steps = 100, # Checkpoint Saving
max_grad_norm=0.1,
report_to="wandb", # Set to None if you don't want to use wandb but highly reccomended
log_on_each_node=False, # For Multi-GPU Setup
)

The num_generations argument here defines how many generations per prompt are sampled.

Next, we pass the LoRA config. Note: LoRA finetuning currently does not work with a multi-GPU setup.

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj"
    ],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

Make sure you have flash-attention installed for faster training.

pip install ninja
pip install flash-attn --no-build-isolation

Next, we load the base model and tokenizer to pass them to the trainer.

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16, # Set to torch.float16 if you don't have an Ampere GPU or newer
    attn_implementation="flash_attention_2", # Make sure you have FA2 installed
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

Now we initialize the trainer.

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ], # The GRPO Magic ✨
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config # Pass only for single-GPU training
)

Finally we fire up the trainer 🔥

trainer.train()

And voilà, you now have your own reasoning model!

While finetuning takes place, you can monitor the following metrics logged to wandb:

  • completion_length: The average completion length.
  • reward/{reward_func_name}: The reward computed by each reward function. We should see 5 separate graphs, since we pass 5 functions, letting us monitor each function’s performance individually.
  • reward: The average reward.
  • reward_std: The average standard deviation of rewards within each group.
  • kl: The average KL divergence between the model and the reference model, calculated on completions.

Once your model is trained and ready to use, it should be generating responses in the <reasoning>/<answer> format we trained for.
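
As a hypothetical illustration (reusing the SYSTEM_PROMPT from earlier and the standard transformers generation API), querying the finetuned model might look like this:

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "Solve for y: 5y - 36 = 14"},
]
# If you trained with LoRA, generate from trainer.model (or a saved checkpoint) instead of the base model
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
output = model.generate(inputs, max_new_tokens=512)
print(tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True))

A well-trained run should produce something along these lines (illustrative, not an actual model output):

<reasoning>
Add 36 to both sides to get 5y = 50, then divide both sides by 5 to get y = 10.
</reasoning>
<answer>
10
</answer>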
