This model is an actor checkpoint of Llama 3.2 3B, fine-tuned on GSM8K with PPO (Proximal Policy Optimization) using the veRL framework.
Training aims to improve mathematical reasoning: the policy was optimized with reward-based learning to generate more accurate step-by-step solutions to math word problems.
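PPO needs a scalar reward for each sampled solution. The card does not include the reward function, but a common choice for GSM8K is rule-based exact match on the final answer (GSM8K reference solutions end in `#### <answer>`). The sketch below is an assumption along those lines, not the actual training code:

```python
import re

def gsm8k_reward(response: str, reference: str) -> float:
    """Hypothetical rule-based reward: 1.0 if the last number in the
    model's response equals the reference final answer, else 0.0."""
    # GSM8K reference solutions end with '#### <answer>'.
    gold = reference.split("####")[-1].strip().replace(",", "")
    # Take the last number that appears in the model's response.
    numbers = re.findall(r"-?\d+\.?\d*", response.replace(",", ""))
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1].rstrip(".") == gold else 0.0
```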
The published checkpoint was selected automatically via best-of-n evaluation across multiple training steps, keeping the best-scoring step.
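As an illustration of that selection step (the checkpoint list and the `evaluate_accuracy` helper here are hypothetical):

```python
# Hypothetical sketch of best-of-n selection: score every saved PPO step
# on a held-out set and publish the checkpoint with the highest accuracy.
def select_best_checkpoint(checkpoint_dirs, evaluate_accuracy):
    scores = {ckpt: evaluate_accuracy(ckpt) for ckpt in checkpoint_dirs}
    best = max(scores, key=scores.get)
    return best, scores[best]
```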
Load the model and run it on a GSM8K-style prompt:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the checkpoint in half precision and spread it across available devices.
model = AutoModelForCausalLM.from_pretrained(
    "samhitha2601/llama3.2-3b-ppo",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("samhitha2601/llama3.2-3b-ppo")

# Example GSM8K problem
prompt = """Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Answer: Let's solve this step by step:"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
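Sampling with `temperature=0.7` produces varied solutions; for reproducible checks you may prefer greedy decoding. This is a variant of the call above, not necessarily how the checkpoint was evaluated:

```python
# Greedy decoding: deterministic output, sampling disabled.
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
# Decode only the newly generated tokens, skipping the prompt.
completion = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(completion, skip_special_tokens=True))
```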
You can also prompt the model through the tokenizer's chat template:

```python
messages = [
    {"role": "user", "content": "Solve this math problem: If a train travels 60 miles per hour for 2.5 hours, how far does it travel?"}
]
# Build the prompt with the tokenizer's chat template.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,  # include the attention mask
    return_tensors="pt",
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
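To check accuracy end to end, you can score exact match over the GSM8K test split. This is a minimal sketch reusing the hypothetical `gsm8k_reward` from above, not the official evaluation:

```python
from datasets import load_dataset

# Small slice of the GSM8K test split for a quick sanity check.
dataset = load_dataset("openai/gsm8k", "main", split="test").select(range(50))

correct = 0.0
for example in dataset:
    prompt = f"Question: {example['question']}\nAnswer: Let's solve this step by step:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    correct += gsm8k_reward(response, example["answer"])

print(f"Exact-match accuracy on {len(dataset)} problems: {correct / len(dataset):.2%}")
```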
This model has been trained with PPO to maximize reward on GSM8K problems, showing improved mathematical reasoning and more accurate step-by-step solutions to math word problems.
If you use this model, please cite:
```bibtex
@misc{llama32-gsm8k-ppo,
  title        = {Llama 3.2 3B Fine-tuned on GSM8K with PPO},
  author       = {Your Name},
  year         = {2025},
  howpublished = {\url{https://huggingface.co/samhitha2601/llama3.2-3b-ppo}},
}
```
Base model: meta-llama/Llama-3.2-3B