
vb-creator/TRtutor


TRTutor: Talker-Reasoner Math Tutor

This repository contains code to train and deploy a Socratic "talker-reasoner" framework that implements a math tutor. It includes:

  • Dataset Generation: scripts that build a training dataset of Talker-Reasoner-enhanced dialogues from raw MathDial conversation samples.
  • Reasoner Context Generation: code that generates the chain of thought, belief state, and final answer for each problem using state-of-the-art reasoning models such as o3-mini and DeepSeek-R1.
  • Talker Fine‑tuning: finetuning/sft_talker.py runs supervised fine‑tuning with LoRA adapters.
  • Inference: inference/talker_predict.py to generate tutor responses at inference time.

Installation

  1. Create a Python 3.9+ environment managed by Poetry.

  2. Install Poetry, then install the dependencies:

    poetry install
  3. (Optional) Configure Weights & Biases (W&B) to track experiment results:

    wandb login
    export WANDB_PROJECT="talker-tutor"

Dataset Generation

Use dataset_generation/enhanced_dialogue_generation.py to generate the reasoner_context and a personalized conversation dataset using OpenAI models.

python dataset_generation/enhanced_dialogue_generation.py \
  --input_dir data/mathdial_df.pkl \
  --output_json tr_data/train_mathdial.json
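Before a raw conversation can be enhanced, it has to be split into speaker turns. The sketch below assumes each raw MathDial conversation is a single string whose turns are separated by `|EOM|` and prefixed with `Speaker:`; verify this against data/mathdial_df.pkl. `split_turns` is a hypothetical helper for illustration, not a function from this repository.

```python
# Assumption: raw MathDial conversations are single strings whose turns are
# separated by "|EOM|", each turn starting with "Speaker: ".
TURN_SEPARATOR = "|EOM|"


def split_turns(conversation: str) -> list[dict]:
    """Hypothetical helper: split a raw conversation into speaker/text turns."""
    turns = []
    for raw_turn in conversation.split(TURN_SEPARATOR):
        raw_turn = raw_turn.strip()
        if not raw_turn:
            continue  # skip empty fragments, e.g. from a trailing separator
        speaker, sep, text = raw_turn.partition(":")
        if not sep:
            # No "Speaker:" prefix found; keep the whole turn as text.
            speaker, text = "unknown", raw_turn
        turns.append({"speaker": speaker.strip(), "text": text.strip()})
    return turns


print(split_turns("Teacher: What did you get?|EOM|Student: I got 12.|EOM|"))
```

Splitting on the first colon only keeps any colons inside the utterance itself (for example in times or ratios) intact.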

Use dataset_generation/prepare_sft_data.py to format the dataset for supervised fine‑tuning (SFT). Each record has three fields:

  • instruction: the talker prompt template
  • input: JSON with student_persona, reasoner_context, conversation_history
  • output: the teacher’s next response
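A minimal sketch of one SFT record with the three fields above. It assumes the "input" field is stored as a serialized JSON string; prepare_sft_data.py may instead keep it as a nested object. `build_sft_record` is a hypothetical helper, not part of the repository.

```python
import json


def build_sft_record(instruction, student_persona, reasoner_context,
                     conversation_history, teacher_response):
    """Hypothetical helper: assemble one instruction/input/output record."""
    payload = {
        "student_persona": student_persona,
        "reasoner_context": reasoner_context,
        "conversation_history": conversation_history,
    }
    return {
        "instruction": instruction,  # the talker prompt template
        # Assumption: "input" is serialized to a JSON string.
        "input": json.dumps(payload, ensure_ascii=False),
        "output": teacher_response,  # the teacher's next response (training target)
    }


record = build_sft_record(
    instruction="You are a Socratic math tutor.",
    student_persona="Confident but rushes through multiplication.",
    reasoner_context={"final_answer": "24"},
    conversation_history=[{"speaker": "Student", "text": "I think it's 11."}],
    teacher_response="How did you get 11? Can you walk me through it?",
)
print(sorted(record))  # ['input', 'instruction', 'output']
```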

Reasoner Context Generation

Compute the chain of thought, belief state, and final answer with OpenAI models (requires an OpenAI API key) or DeepSeek-R1 (requires a Fireworks API key):

python scripts/reasoner_context_generation.py \
  --problems data/raw/problems.json \
  --solutions data/raw/solutions.json \
  --output_json data/tr_data/train_mathdial.json

The resulting reasoner_context is added to the context shared with the talker.
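The three reasoner outputs can be carried around as one small object. This is a sketch under the assumption that the belief state is a free-form dictionary; the `ReasonerContext` class and its field types are hypothetical, and the actual JSON produced by reasoner_context_generation.py may use different keys.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ReasonerContext:
    """Hypothetical container for the reasoner outputs named in this README."""
    chain_of_thought: str                              # step-by-step solution reasoning
    belief_state: dict = field(default_factory=dict)   # inferred student understanding
    final_answer: str = ""                             # the problem's correct answer

    def to_json(self) -> str:
        # Serialize for storage alongside each dialogue sample.
        return json.dumps(asdict(self), ensure_ascii=False)


ctx = ReasonerContext(
    chain_of_thought="3 boxes * 8 pencils = 24 pencils",
    belief_state={"misconception": "adds instead of multiplying"},
    final_answer="24",
)
print(ctx.to_json())
```

Keeping the final answer separate from the chain of thought makes it easy for the talker prompt to reference the answer without repeating it to the student.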


Talker Fine‑tuning

Run finetuning/sft_talker.py with torchrun. The command below requires two GPUs:

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc-per-node=2 finetuning/sft_talker.py \
  --model_name Qwen/Qwen2.5-7B-Instruct \
  --bf16 \
  --data_path data/tr_data/train_mathdial.json \
  --output_dir finetuned_models/trt_qwen_2point5_7b \
  --cache_dir data/cache \
  --model_max_length 2048 \
  --per_device_train_batch_size 2 \
  --gradient_accumulation_steps 16 \
  --learning_rate 2e-5 \
  --save_strategy steps \
  --save_steps 50 \
  --save_total_limit 5 \
  --deepspeed deepspeed/config.json
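With these flags, each optimizer update sees 2 GPUs × 2 sequences per device × 16 accumulation steps = 64 sequences, and a checkpoint is written every 50 optimizer steps. The arithmetic is sketched below, assuming deepspeed/config.json leaves the batch-size settings to the command-line arguments (for example via "auto" values); the helper names are illustrative only.

```python
def effective_batch_size(num_gpus: int, per_device_batch: int, grad_accum_steps: int) -> int:
    """Number of sequences that contribute to one optimizer update."""
    # Each GPU processes per_device_batch sequences per forward/backward pass,
    # and gradients are accumulated over grad_accum_steps passes before stepping.
    return num_gpus * per_device_batch * grad_accum_steps


def sequences_between_checkpoints(save_steps: int, global_batch: int) -> int:
    """Sequences processed between two saves when saving every save_steps updates."""
    return save_steps * global_batch


global_batch = effective_batch_size(num_gpus=2, per_device_batch=2, grad_accum_steps=16)
print(global_batch)                                      # 64
print(sequences_between_checkpoints(50, global_batch))   # 3200
```

If you change the number of GPUs, scaling --gradient_accumulation_steps inversely keeps the effective batch size (and therefore the behavior of --learning_rate 2e-5) roughly comparable.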

Checkpointing: Use --resume_from_checkpoint <path> to resume from a saved checkpoint.
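To pick the most recent checkpoint path automatically, a small helper can scan the output directory. This assumes sft_talker.py saves through the Hugging Face Trainer, which names its folders checkpoint-<global_step> (the flags above suggest it does, but verify); `latest_checkpoint` is a hypothetical helper. The transformers library also ships `transformers.trainer_utils.get_last_checkpoint` for the same purpose.

```python
import re
from pathlib import Path
from typing import Optional

# Assumption: checkpoints are saved as "checkpoint-<global_step>" folders.
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint folder with the highest step, or None if there is none."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    candidates = []
    for child in root.iterdir():
        match = CHECKPOINT_PATTERN.match(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    # Compare by numeric step so checkpoint-1000 beats checkpoint-950.
    return str(max(candidates, key=lambda c: c[0])[1])


print(latest_checkpoint("finetuned_models/trt_qwen_2point5_7b"))
```

Because --save_total_limit 5 deletes older checkpoints, the returned folder is one of at most five kept on disk.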


Inference

Generate tutor responses with inference/talker_predict.py:

python inference/talker_predict.py \
  --adapter_path finetuned_models/trt_qwen_2point5_7b \
  --test_json data/tr_data/test_mathdial.json \
  --output_json data/tr_data/predictions.json

This loads the fine‑tuned adapter checkpoint and generates the next teacher response for each conversation history in the test file.
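The inference loop can be sketched independently of the model. Here `generate(instruction, input_text)` is a hypothetical stand-in for the loaded model plus LoRA adapter, and the test records are assumed to follow the instruction/input/output format used for SFT; the output keys are illustrative and may not match predictions.json.

```python
import json
from typing import Callable


def predict_all(test_records: list[dict],
                generate: Callable[[str, str], str]) -> list[dict]:
    """Run a generation function over SFT-style test records."""
    predictions = []
    for record in test_records:
        predictions.append({
            "input": record["input"],
            # Keep the reference teacher reply for later comparison or evaluation.
            "reference": record.get("output"),
            "prediction": generate(record["instruction"], record["input"]),
        })
    return predictions


def save_predictions(predictions: list[dict], path: str) -> None:
    with open(path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, ensure_ascii=False, indent=2)


# A dummy generator stands in for the real model here.
records = [{"instruction": "Be Socratic.", "input": "{}", "output": "What do you see?"}]
print(predict_all(records, lambda ins, inp: "Can you explain your first step?"))
```

Passing the generator in as a function keeps the loop testable without loading a 7B model.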


Prompts

  • talker_prompt.txt: details the Socratic instruction template used by the talker.
  • reasoner_prompt.txt: defines how the reasoner generates the chain of thought and belief state from the given information.
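A prompt file like these can be filled with Python's `string.Template`. The template text and placeholder names below are hypothetical stand-ins; the real talker_prompt.txt may use a different wording and substitution scheme. In practice the template would be loaded with something like `Template(Path("talker_prompt.txt").read_text())`, adjusting the path to where the file lives.

```python
from string import Template

# Hypothetical stand-in for talker_prompt.txt; real wording and placeholders may differ.
EXAMPLE_TALKER_TEMPLATE = Template(
    "You are a Socratic math tutor. Guide the student with questions "
    "rather than giving the answer.\n"
    "Student persona: $student_persona\n"
    "Reasoner notes: $reasoner_context\n"
    "Conversation so far:\n$conversation_history\n"
    "Teacher:"
)


def render_prompt(template: Template, **fields: str) -> str:
    # substitute() raises KeyError for any unfilled placeholder, which surfaces
    # mismatches between the template and the data early.
    return template.substitute(**fields)


print(render_prompt(
    EXAMPLE_TALKER_TEMPLATE,
    student_persona="Rushes through multiplication.",
    reasoner_context="Final answer: 24; student added instead of multiplying.",
    conversation_history="Student: I think it's 11.",
))
```

Using `substitute` rather than `safe_substitute` is a deliberate choice: a silently unfilled placeholder would otherwise leak into training data.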

Data and Results

The datasets used for fine‑tuning, along with the inference results and conversation simulations generated for LLM‑based evaluation, can be found here.
