Reasoning Trace Distillation: Preserving Multi-Step Logic in Compressed Student Models

Knowledge distillation remains the dominant paradigm for compressing large language models into deployable student models. Yet a persistent failure mode undermines its reliability for the most consequential applications. When a 70-billion-parameter teacher is distilled into an 8-billion-parameter student using standard logit matching, the student often matches teacher perplexity within three percent. On standard benchmarks, the distillation appears successful. On tasks requiring multi-step reasoning, however, the same student exhibits catastrophic performance degradation that far exceeds what the perplexity gap would predict.

We have investigated this failure mode systematically at KriraAI and found that the root cause is a fundamental mismatch between what standard knowledge distillation objectives optimise and what multi-step reasoning requires. Standard KD optimises for pointwise distributional alignment at each token position independently. Multi-step reasoning is a sequential process where each step's validity depends on prior conclusions. A student achieving 92 percent step-level accuracy still fails roughly 28 percent of four-step problems because errors compound multiplicatively. Standard KD has no mechanism to penalise this compounding.
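The compounding arithmetic can be checked directly. A minimal sketch, assuming per-step errors are independent (an illustrative simplification, not a claim about the measured data):

```python
# Independent per-step success compounds multiplicatively across a chain.
def chain_success(step_accuracy: float, num_steps: int) -> float:
    return step_accuracy ** num_steps

print(round(chain_success(0.92, 4), 3))  # 0.716
```

At 92 percent step accuracy, only about 71.6 percent of four-step chains survive intact, so roughly 28 percent fail.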

Our research introduces Reasoning Trace Distillation (RTD), a framework that decomposes teacher chain of thought into atomic steps, aligns student representations at step boundaries, and introduces a reasoning coherence loss function that penalises transitions deviating from the teacher's reasoning trajectory. With curriculum-scheduled training progressing from single-step to multi-step problems, RTD improves full-chain accuracy on four-step-and-above problems by 18.7 percentage points over standard KD when distilling from a 70B teacher to an 8B student. This blog presents the full methodology, experiments, and findings.

The Reasoning Distillation Gap: Why Standard KD Fails

To understand why chain of thought distillation fidelity collapses on multi-step tasks, we must examine how reasoning errors propagate in distilled models. When a student processes a four-step reasoning chain, its hidden state representations progressively diverge from the teacher's trajectory. We term this phenomenon coherence drift. After step one, the student's hidden state aligns well with the teacher's. By step four, the accumulated drift pushes the student's state into a representational region where the correct next-step transition is inaccessible.

We measured this drift by computing centred kernel alignment between teacher and student hidden states at each step boundary across 3,200 reasoning chains from GSM8K. At step one, CKA similarity averages 0.89. At step two, 0.81. At step three, 0.72. At step four, 0.58. This monotonic degradation explains the compounding failure. The student is not making independent errors but progressively losing alignment with the teacher's reasoning trajectory.
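The drift measurement can be reproduced with linear CKA, one common variant; the text does not specify which kernel was used, so this sketch assumes the linear form, with `H` standing in for hidden states stacked across examples at one step boundary:

```python
import numpy as np

def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Linear centred kernel alignment between two (n_samples, dim) matrices."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    numerator = np.linalg.norm(Y.T @ X, "fro") ** 2
    denominator = np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")
    return float(numerator / denominator)

rng = np.random.default_rng(0)
H = rng.standard_normal((200, 64))                      # toy hidden states
print(round(linear_cka(H, H), 3))                       # identical states -> 1.0
print(linear_cka(H, H + 0.5 * rng.standard_normal(H.shape)) < 1.0)  # drift -> True
```

Computing this at each step boundary across a set of chains yields the per-step similarity profile reported above.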

Existing approaches to improving knowledge distillation reasoning do not address this mechanism directly. FitNets-style intermediate matching aligns representations at fixed network layers rather than reasoning step boundaries, applying a spatial alignment strategy to what is fundamentally a temporal coherence problem. The teacher's layer 16 representations do not correspond to any particular reasoning step, so aligning at this layer provides no guarantee of step-level coherence. Chain of thought fine-tuning, where the student trains on teacher-generated reasoning traces, teaches the surface form of reasoning without constraining internal representation dynamics. The student learns to produce text that looks like reasoning without being required to maintain the internal state trajectory that makes reasoning reliable.

Process reward models score individual reasoning steps but evaluate each step in isolation, providing no gradient signal for the coherence between consecutive steps. Two individually high-scoring steps can still represent a coherence-breaking transition if the student's internal state drifted between them. Step-level distillation losses improve individual step quality but treat each step independently, leaving the multi-step reasoning distillation compounding problem entirely unsolved.

The deployment consequence is severe and increasingly relevant. Organisations distill large models to reduce inference costs and latency, but if the distilled model fails on complex reasoning while handling simple queries well, the system becomes unreliable in a particularly dangerous way. Failure modes that correlate with query difficulty mean the system fails precisely when the stakes are highest and when the user most needs reliable output.

Reasoning Trace Distillation: Methodology

Core Insight

The fundamental insight behind RTD is that reasoning capability resides not in individual token predictions but in representation trajectory. A model reasons well when its hidden state follows a coherent path through representation space, where each transition enables the correct subsequent step. Standard KD optimises positional accuracy. RTD optimises trajectory fidelity. This shift changes what the student learns from producing correct tokens to maintaining a reasoning-compatible internal state across steps.

Core components of Reasoning Trace Distillation:

1. Teacher Reasoning Decomposition
2. Step-Boundary Representation Alignment
3. Reasoning Coherence Loss
4. Curriculum-Scheduled Training
5. Coherence-Aware Evaluation

Teacher Reasoning Decomposition

RTD first decomposes teacher reasoning chains into atomic steps with annotated boundaries. A step boundary is the token position where one logical deduction concludes and the next begins. We implement detection using a lightweight classifier trained on 4,500 manually annotated chains, operating on teacher hidden states to identify positions where the representation shifts direction discontinuously. The classifier achieves 94.1 percent F1 on held-out annotations. At each boundary, we extract the teacher's hidden state as a snapshot, producing a sequence h_T_1 through h_T_k for a k-step chain.
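Given per-token hidden states and the boundary positions the classifier flags, extracting the snapshot sequence is straightforward. A minimal sketch; `boundary_snapshots` and its arguments are illustrative names, not the original implementation:

```python
import numpy as np

def boundary_snapshots(hidden_states: np.ndarray, boundaries) -> np.ndarray:
    """hidden_states: (seq_len, dim) teacher states for one chain.
    boundaries: token indices flagged by the (assumed) boundary classifier.
    Returns the (k, dim) snapshot sequence h_T_1 .. h_T_k."""
    return np.stack([hidden_states[p] for p in boundaries])

states = np.arange(20.0).reshape(5, 4)      # toy 5-token chain, dim 4
snaps = boundary_snapshots(states, [1, 4])  # two step boundaries
print(snaps.shape)                          # (2, 4)
```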

Step-Boundary Representation Alignment

The second component aligns student hidden states with teacher states at step boundaries rather than every token. The loss is L_align = (1/k) * sum(1 - cos(f(h_S_i), h_T_i)), where f is a learned projection from student to teacher representation space. Critically, we do not constrain the student between boundaries, giving it freedom to develop its own computation strategy for each step while requiring agreement at step conclusions.
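The boundary alignment loss can be sketched in numpy as follows, with a fixed matrix `W` standing in for the learned projection f (in training, f would be a trainable layer):

```python
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def alignment_loss(student_states, teacher_states, W) -> float:
    """L_align = (1/k) * sum_i (1 - cos(f(h_S_i), h_T_i)),
    with W standing in for the learned projection f."""
    k = len(student_states)
    return sum(1.0 - cosine(W @ h_s, h_t)
               for h_s, h_t in zip(student_states, teacher_states)) / k

rng = np.random.default_rng(1)
states = [rng.standard_normal(16) for _ in range(4)]  # toy boundary states
print(alignment_loss(states, states, np.eye(16)) < 1e-9)  # perfect match -> True
```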

The Reasoning Coherence Loss

The reasoning coherence loss function is our primary contribution. While boundary alignment ensures the student reaches correct representational waypoints, it does not constrain the transitions between them. A student could reach the correct state at step boundary two through an entirely different internal path than the teacher took, and this alternative path might not generalise to step boundary three. The coherence loss directly addresses this gap.

We define step transition vectors as delta_i = h_{i+1} - h_i, representing the direction and magnitude of representational change between consecutive steps. The coherence loss is formulated as L_coherence = (1/(k-1)) * sum(1 - cos(g(delta_S_i), delta_T_i))^2, where g is a learned projection and the squaring amplifies penalties for large deviations while tolerating small ones. We chose squaring over absolute value because empirically it produced more stable training dynamics, avoiding the gradient discontinuity at zero that absolute value introduces.
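A minimal numpy sketch of the coherence loss, with a fixed matrix `G` standing in for the learned projection g:

```python
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def coherence_loss(student_states, teacher_states, G) -> float:
    """L_coherence = (1/(k-1)) * sum_i (1 - cos(g(delta_S_i), delta_T_i))^2,
    where delta_i = h_{i+1} - h_i and G stands in for the projection g."""
    d_s = np.diff(np.stack(student_states), axis=0)  # student transition vectors
    d_t = np.diff(np.stack(teacher_states), axis=0)  # teacher transition vectors
    return float(np.mean([(1.0 - cosine(G @ ds, dt)) ** 2
                          for ds, dt in zip(d_s, d_t)]))

rng = np.random.default_rng(3)
states = [rng.standard_normal(8) for _ in range(4)]      # 4 boundaries, 3 transitions
print(coherence_loss(states, states, np.eye(8)) < 1e-9)  # matching trajectory -> True
```

Note that only the differences between consecutive boundary states enter the loss, so the student is penalised for deviating in transition direction even when its absolute states are close to the teacher's.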

This loss encodes the inductive bias that the direction of representational change between steps carries information about the logical relationship between those steps. A transition moving in the teacher's direction maintains reasoning coherence. One drifting orthogonally is losing the logical thread. Our experiments confirm this assumption holds quantitatively. Transition vector cosine similarity between teacher and student predicts downstream step correctness with AUROC 0.83, a strong signal that validates the core hypothesis of this work.

Curriculum-Scheduled Training

We schedule training from one-step problems to progressively longer chains. Each stage trains until coherence loss on that stage's validation set falls below a teacher-calibrated threshold. This curriculum is essential, not merely convenient. Optimising coherence loss on four-step problems from the start fails because early-training representations are too poorly aligned for transition vectors to carry meaningful directional signal.
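The staged loop can be sketched as follows; `train_stage` and `validate` are assumed callbacks standing in for a training pass and a validation-set coherence-loss measurement, not part of the original implementation:

```python
def curriculum_train(stages, train_stage, validate, threshold: float) -> None:
    """stages: datasets ordered by reasoning depth (single-step first).
    Advance to the next stage only once the coherence loss on the current
    stage's validation split falls below the threshold."""
    for dataset in stages:
        while validate(dataset) > threshold:
            train_stage(dataset)
```

A usage sketch with stub callbacks: `curriculum_train(["one_step", "two_step"], run_epoch, eval_coherence, threshold=0.5)` would keep training on one-step problems until their validation coherence loss clears 0.5, then move on.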

The total objective is L_RTD = L_KD + 0.3 * L_align + 0.5 * L_coherence. The coherence loss receives the highest weight, reflecting its greater contribution; both weights were selected by grid search on held-out data.
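Combining the three terms is then a one-liner; the weights below are the grid-searched values quoted above:

```python
def rtd_loss(l_kd: float, l_align: float, l_coherence: float,
             w_align: float = 0.3, w_coherence: float = 0.5) -> float:
    """L_RTD = L_KD + 0.3 * L_align + 0.5 * L_coherence."""
    return l_kd + w_align * l_align + w_coherence * l_coherence

print(rtd_loss(1.0, 2.0, 4.0))  # 1.0 + 0.6 + 2.0 = 3.6
```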

Experimental Design

Datasets and Benchmarks

We evaluated RTD across four benchmarks testing different multi-step reasoning types.

  • GSM8K: 1,319 math problems requiring two to eight reasoning steps, serving as the primary sequential reasoning benchmark.

  • MATH (difficulty 3-5): 2,100 competition problems requiring formal mathematical reasoning with non-trivial intermediate deductions.

  • StrategyQA: 2,290 multi-hop commonsense questions testing generalisation beyond mathematical domains.

  • ARC-Challenge: 1,172 science questions requiring multi-step causal reasoning.

Baselines and Configuration

We compared RTD against four baselines.

  • Standard KD: Logit matching with temperature 4.0 and alpha 0.7.

  • Intermediate KD: FitNets-style layer matching at four evenly spaced teacher layers.

  • CoT Fine-tuning: Student trained on 50,000 teacher-generated reasoning traces.

  • Step-Level Reward Distillation: Student trained with per-step reward model feedback.

All experiments used Llama-3-70B-Instruct as teacher and Llama-3-8B-Instruct as primary student, with additional 1B student experiments. Training ran on 8 NVIDIA A100 80GB GPUs, batch size 64, learning rate 2e-5 with cosine annealing, 40,000 total steps across curriculum stages.

Results and Analysis

Main Findings

RTD achieves its largest gains on four-step-and-above problems. On GSM8K four-plus-step problems, RTD reaches 71.3 percent full-chain accuracy versus 52.6 percent for standard KD, an 18.7 point improvement. On MATH difficulty 4-5, RTD achieves 38.4 percent versus 29.1 percent. On StrategyQA, the gain is 4.8 points. On ARC-Challenge, 6.2 points.

The gains scale with reasoning depth. On two-step GSM8K problems, RTD improves by only 1.9 points. On three-step problems, 8.3 points. On five-plus-step problems, 23.1 points. This confirms RTD addresses compounding coherence drift rather than providing general quality improvement. The student model reasoning capability for simple problems was already well-served by standard KD.

Ablation Studies

We isolated each component's contribution through systematic ablation on GSM8K and MATH.

  • Removing coherence loss (alignment only): four-step improvement drops from 18.7 to 7.1 points. Coherence loss accounts for 62 percent of RTD's total gain.

  • Removing curriculum (all depths simultaneously): improvement drops to 11.4 points. Curriculum scheduling contributes 24 percent of the gain.

  • Removing alignment (coherence only): improvement drops to 13.2 points. Alignment provides the representational foundation, contributing the remaining 14 percent.

An unexpected finding emerged on simple problems. The 8B RTD student achieved 89.7 percent on two-step GSM8K problems, compared to 88.4 percent for the 70B teacher. The student outperformed the teacher by 1.3 points. We attribute this to the coherence loss regularising reasoning trajectories, eliminating variance that causes even the teacher occasional simple-problem failures. This finding reproduced across three independent runs.

Failure Cases

RTD does not improve problems requiring creative insight or non-decomposable leaps of logic. On MATH geometry problems requiring spatial insight, RTD's improvement was only 0.8 points versus 5.7 on algebraic problems. The sequential step assumption is a genuine limitation. RTD also showed diminished returns for the 1B student, where four-step improvement dropped to 9.3 points. Below a capacity threshold, students lack representational bandwidth for coherent trajectories regardless of training objective.

Discussion and Implications

The most significant implication is that perplexity and step-level accuracy are insufficient metrics for evaluating multi-step reasoning distillation. Standard KD achieves 90.4 percent step accuracy on GSM8K but only 52.6 percent chain accuracy on four-step problems. This 37.8 point gap represents compounding coherence drift that standard metrics completely miss. The field's current practice of evaluating distillation through perplexity systematically overestimates distilled model reasoning capability. Any organisation deploying a distilled model for reasoning-intensive workloads should evaluate chain-level accuracy stratified by reasoning depth, not aggregate metrics that average away the failure mode.
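Stratifying chain-level accuracy by reasoning depth is simple to implement. A sketch, assuming per-example records of (depth, fully-correct) pairs; the record format is illustrative:

```python
from collections import defaultdict

def chain_accuracy_by_depth(records):
    """records: iterable of (num_steps, chain_correct) pairs.
    Returns {depth: fraction of chains fully correct at that depth}."""
    totals, correct = defaultdict(int), defaultdict(int)
    for depth, ok in records:
        totals[depth] += 1
        correct[depth] += bool(ok)
    return {d: correct[d] / totals[d] for d in sorted(totals)}

print(chain_accuracy_by_depth([(2, True), (2, True), (4, True), (4, False)]))
# {2: 1.0, 4: 0.5}
```

Reporting this dictionary rather than a single aggregate number is what exposes the depth-correlated failure mode that perplexity hides.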

The finding that coherence loss contributes 62 percent of improvement while boundary alignment contributes 14 percent reveals something fundamental about how reasoning works in neural networks. Maintaining correct transitions between representations matters substantially more than reaching correct intermediate representations. A student that arrives at slightly different intermediate states but follows the same reasoning direction through representation space outperforms one that matches intermediate states precisely but drifts between them. This parallels neuroscience findings on motor learning where trajectory dynamics outweigh endpoint accuracy, and it suggests reasoning in transformers is better understood as a dynamical process than a sequence of independent computations.

The counterintuitive finding that RTD students outperform the teacher on two-step problems by 1.3 points deserves particular attention. The coherence loss appears to act as a regulariser that tightens the student's reasoning trajectories, eliminating the variance that causes even capable teacher models to occasionally err on simple problems. This suggests that explicit trajectory constraints may benefit even large models, and that self-distillation with coherence objectives could improve teacher model reasoning without any compression.

For practitioners, KriraAI has developed diagnostic tooling around coherence drift measurement. If a distilled model performs well on simple queries but fails unpredictably on complex ones, measuring CKA at reasoning step boundaries reveals whether coherence drift is responsible. In our enterprise deployments, the coherence drift metric predicts production failure rates on reasoning-heavy workloads with high reliability. As enterprise AI moves toward agentic multi-step workflows, the reliability of compressed reasoning becomes a deployment-critical concern that trajectory-aware distillation directly addresses.

Limitations and Future Work

The step-boundary classifier was trained on mathematical and logical reasoning traces. Its performance on other modalities, including moral reasoning, strategic planning, or hypothesis generation, is untested. The sequential decomposition assumption limits RTD's applicability to reasoning types that are genuinely sequential. Abductive reasoning and creative problem solving may require entirely different distillation strategies.

RTD adds approximately 60 percent training time over standard KD through curriculum scheduling, step detection, and transition computation. This overhead may be prohibitive for compute-constrained organisations. We are investigating distilled coherence losses that approximate transition alignment without explicit step detection. KriraAI is also studying the interaction between RTD and quantisation, since 4-bit PTQ of an RTD-trained student reduced the multi-step improvement from 18.7 to 11.2 points, retaining most but not all gains. Joint optimisation across compression axes remains an open problem.

Conclusion

This research makes three contributions to knowledge distillation reasoning for multi-step tasks. First, we identified coherence drift as the mechanism behind disproportionate reasoning failure in distilled models, showing CKA alignment degrades from 0.89 to 0.58 across four steps. Second, we introduced Reasoning Trace Distillation, whose reasoning coherence loss function accounts for 62 percent of improvement by optimising trajectory fidelity rather than pointwise matching. Third, we demonstrated 18.7 percentage points of chain accuracy improvement on four-step problems and discovered that the coherence loss regularises student reasoning enough to outperform the teacher on simpler problems by 1.3 points.

These findings suggest model compression for reasoning applications requires a shift from output-matching to trajectory-matching objectives. As deployment increasingly demands reliable multi-step reasoning in compact models, coherence of the reasoning process becomes the critical optimisation target.

This work represents one contribution within KriraAI's broader research programme on deploying reliable AI reasoning at production scale. We conduct applied research to solve the technical problems preventing trustworthy enterprise AI, not merely to advance benchmarks. We invite researchers working on knowledge distillation, chain of thought distillation fidelity, and model compression to engage with these findings, replicate the approach, and explore collaboration on the open problems this work raises.

FAQs

Why does standard knowledge distillation fail on multi-step reasoning?

Standard knowledge distillation optimises for token-level distributional alignment, treating each output position independently. Multi-step reasoning requires that the model's internal state after each step provides the correct foundation for the subsequent step, a sequential dependency that pointwise losses cannot capture. Our measurements show centred kernel alignment between teacher and student degrades monotonically across steps, from 0.89 at step one to 0.58 at step four. This coherence drift causes errors to compound, so a student with 92 percent step accuracy achieves only 72 percent chain accuracy on four-step problems. Perplexity averages across all tokens, masking this compounding because most tokens are not at critical step boundaries.

What is the reasoning coherence loss, and how does it differ from standard representation matching?

The reasoning coherence loss function penalises the student when its representational transitions between reasoning steps deviate from the teacher's transitions. Unlike standard representation matching that aligns absolute positions of hidden states at fixed layers or tokens, the coherence loss operates on transition vectors defined as differences between hidden states at consecutive step boundaries. It is formulated as squared cosine distance between projected student and teacher transition vectors, averaged across all transitions. This encodes the inductive bias that the direction of representational change between steps carries information about logical relationships. Transition vector cosine similarity predicts downstream step correctness with AUROC 0.83, validating this assumption.

How much overhead does RTD add to training and inference?

RTD adds approximately 60 percent to total training time from three sources: curriculum scheduling requiring multiple convergence stages, step-boundary detection adding a forward pass through a lightweight classifier per batch, and transition vector computation at step boundaries. However, RTD adds zero inference overhead because it modifies only the training objective without changing student architecture. The student produced by RTD has identical inference cost to one from standard KD, making the training overhead a one-time investment amortised over deployment lifetime. For organisations prioritising inference efficiency, this tradeoff is favourable.

What types of reasoning does RTD work best for?

RTD works best for reasoning decomposable into sequential atomic steps where each builds on the previous. Mathematical reasoning, multi-hop factual reasoning, and formal logical inference fit this pattern well. Our experiments showed strong gains on GSM8K, MATH, and StrategyQA. However, RTD shows diminished gains on non-sequential reasoning such as geometry requiring spatial insight, creative problem solving, or abductive reasoning. The step decomposition assumption is a structural limitation. On MATH geometry problems, RTD improved by only 0.8 points compared to 5.7 on algebraic problems. Alternative strategies not assuming sequential structure would be needed for these reasoning types.

Can RTD be combined with quantisation or other compression techniques?

RTD produces a standard-architecture student model compatible with other compression techniques. However, quantisation applied after RTD training can re-introduce coherence drift by perturbing hidden states that RTD aligned. In initial experiments, 4-bit post-training quantisation of an RTD-trained 8B student reduced multi-step reasoning improvement from 18.7 to 11.2 points, retaining substantial but incomplete gains. KriraAI is investigating quantisation-aware training combined with RTD objectives to mitigate this interaction. Structured pruning is likely more compatible than quantisation because it removes parameters rather than perturbing values, though this remains an open empirical question.

Divyang Mandani

CEO

Divyang Mandani is the CEO of KriraAI, driving innovative AI and IT solutions with a focus on transformative technology, ethical AI, and impactful digital strategies for businesses worldwide.

April 15, 2026
