LEARNING MATH REASONING FROM SELF-SAMPLED CORRECT AND PARTIALLY-CORRECT SOLUTIONS

Abstract

Pretrained language models have shown superior performance on many natural language processing tasks, yet they still struggle with multi-step formal reasoning tasks such as grade school math problems. One key challenge in finetuning them to solve such math reasoning problems is that many existing datasets contain only one reference solution per problem, even though there are often alternative solutions representing different reasoning paths to the final answer. As a result, finetuned models are biased towards the limited reference solutions, which limits their generalization to unseen examples. To mitigate this issue, we propose to let the model sample during training and learn from both self-sampled fully-correct solutions, which yield the correct answer upon execution, and partially-correct solutions, whose intermediate state matches an intermediate state of a known correct solution. We show that our use of self-sampled correct and partially-correct solutions can benefit learning and help guide the sampling process, leading to more efficient exploration of the solution space. Additionally, we explore various training objectives that support learning from multiple solutions per example and find that the choice of objective greatly affects performance. Experiments on two math reasoning datasets show the effectiveness of our method compared to learning from a single reference solution with MLE: we improve PASS@100 from 35.5% to 44.5% on GSM8K, and PASS@80 from 27.6% to 36.2% on MathQA. These improvements are consistent across different model sizes. Our code is available at https://github.com/microsoft/TraceCodegen.

1. INTRODUCTION

Recent progress on pretrained language models shows that they are able to achieve human-level performance on various natural language processing tasks with finetuning (Devlin et al., 2019; Brown et al., 2020; Raffel et al., 2020). However, such models still lack the ability to perform multi-step math reasoning, even on problems intended for grade-school students (Cobbe et al., 2021). Current methods for solving math problems typically generate solutions (sequences of computation steps) and execute them to obtain the final answer (Cobbe et al., 2021; Austin et al., 2021; Chen et al., 2021a; Chowdhery et al., 2022), because directly generating the final answer would require computational abilities that even the largest models do not possess (Brown et al., 2020; Chowdhery et al., 2022). When finetuning such models for math reasoning, existing methods often rely on the MLE objective, which maximizes the log-likelihood of the reference solution for each natural language input. Yet besides the reference solution, there are often multiple correct solutions for each question, representing alternative reasoning paths to the final answer. Because these alternative solutions are never seen during training, the model overfits: it becomes overly confident in its predictions after seeing the same solution over multiple epochs of training (Bunel et al., 2018; Austin et al., 2021; Cobbe et al., 2021). This leads to poor generalization on unseen inputs and is reflected in low PASS@k performance, where the model fails to predict the right answer even when allowed multiple attempts (k samples) per question. To mitigate this issue, we propose learning from self-sampled solutions.
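PASS@k counts a question as solved if at least one of k sampled solutions executes to the correct answer. This section does not specify how PASS@k is estimated, so the Python sketch below assumes the unbiased estimator introduced with the HumanEval benchmark: draw n ≥ k samples per question, count the c correct ones, and compute 1 - C(n-c, k) / C(n, k). The function names are illustrative only and are not taken from the paper's code.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased PASS@k estimate from n samples, c of which are correct.

    Equals the probability that a random size-k subset of the n samples
    contains at least one correct solution: 1 - C(n - c, k) / C(n, k).
    """
    if k > n:
        raise ValueError("k must not exceed the number of samples n")
    if n - c < k:
        # Every size-k subset necessarily contains a correct sample.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(results: list[tuple[int, int]], k: int) -> float:
    """Average PASS@k over a dataset, given (n, c) for each question."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)
```

When exactly n = k samples are drawn, the estimate reduces to the plain definition: 1.0 if any sample is correct and 0.0 otherwise.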
Concretely, during training the model samples alternative solutions, keeps track of all solutions that are semantically correct with respect to the gold execution result, and learns from all of these correct solutions rather than only from the reference. To further improve the effectiveness of learning from self-sampled solutions, we allow the model to also learn from partially-correct solutions, whose intermediate states are consistent with intermediate states of known correct solutions. This technique lets the model make the most of self-sampling and explore the solution space more efficiently. We also study several common loss functions for learning from multiple targets for a single natural language input, including augmented-MLE (MLE-Aug), Maximum Marginal Likelihood (MML), and β-smoothed MML (Guu et al., 2017), and find that their different gradient equations greatly affect the learning capabilities of the model. We perform experiments on two math reasoning tasks, MathQA-Python (Austin et al., 2021) and Grade-School-Math (GSM8K) (Cobbe et al., 2021), and finetune GPT-Neo models (Black et al., 2021) to generate Python programs as solutions from natural language problem descriptions. Results show that learning from self-sampled solutions improves PASS@100 from 35.5% to 44.5% on GSM8K, and PASS@80 from 27.6% to 36.2% on a filtered version of MathQA-Python.¹ Moreover, we find that learning from partially-correct solutions generally improves performance over learning from fully-correct solutions alone (e.g., +3.0% PASS@100 on GSM8K), as it guides the sampling process toward discovering more alternative solutions to learn from. These performance gains are also consistent across model sizes. An ablation over loss functions shows that MLE-Aug is the most effective for learning from multiple targets and yields the largest improvements over the MLE loss.
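The three objectives differ mainly in how much gradient each correct solution receives. Their exact definitions appear later in the paper, so this Python sketch assumes the standard forms over a set of correct solutions: MLE-Aug sums -log P(y|x) over the set, MML takes -log of the summed probability, and β-smoothed MML takes -(1/β) log Σ P(y|x)^β, which coincides with MML at β = 1. The sketch works on precomputed log-probabilities instead of a real model, and all function names are hypothetical.

```python
import math
from typing import Optional, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def mle_aug_loss(logps: Sequence[float]) -> float:
    """MLE-Aug: sum of -log P(y|x) over all correct solutions y."""
    return -sum(logps)


def mml_loss(logps: Sequence[float]) -> float:
    """MML: -log sum_y P(y|x), the negative marginal likelihood of the correct set."""
    return -logsumexp(logps)


def beta_mml_loss(logps: Sequence[float], beta: float) -> float:
    """beta-smoothed MML: -(1/beta) log sum_y P(y|x)^beta; equals MML at beta = 1."""
    return -logsumexp([beta * lp for lp in logps]) / beta


def gradient_weights(logps: Sequence[float], beta: Optional[float] = None) -> list[float]:
    """Weights w_y such that grad L = -sum_y w_y * grad log P(y|x).

    beta=None gives MLE-Aug (every w_y = 1); beta=1 gives MML, where
    w_y = P(y|x) / sum P; 0 < beta < 1 flattens those MML weights.
    """
    if beta is None:
        return [1.0] * len(logps)
    z = logsumexp([beta * lp for lp in logps])
    return [math.exp(beta * lp - z) for lp in logps]
```

For instance, with two correct solutions of probability 0.5 and 0.01, MML assigns gradient weights of roughly 0.98 and 0.02, so a rarely sampled alternative contributes almost nothing, whereas MLE-Aug gives both a weight of 1. This is one plausible way to read why the choice of objective matters when learning from self-sampled solutions.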

2. OVERVIEW

Problem formulation. We consider the task of generating solutions from math problem descriptions in natural language (NL). Given an NL input $x \in \mathcal{X}$ and an executor $E: \mathcal{Y} \rightarrow \mathcal{Z}$, the goal is to generate a solution $y \in \mathcal{Y}$ that executes to the expected answer $z^* \in \mathcal{Z}$, i.e., $E(y) = z^*$.

Standard approach and its limitation. The standard approach assumes a dataset of paired NL inputs $x$ and reference solutions $y^*$; most datasets provide only one reference solution for a particular NL input. A parameterized model $P_\theta$ is then learned from each NL-solution pair $(x, y^*)$ with the Maximum Likelihood Estimation (MLE) objective:

$$\mathcal{L}_{\text{MLE}}(x, y^*, P_\theta) = -\log P_\theta(y^* \mid x) \tag{1}$$

The built-in assumption of learning with Eq. (1) is that only the reference solution $y^*$ is correct. This assumption clearly does not hold for math reasoning, where multiple reasoning paths typically reach the correct final result. With only one reference solution as the learning target, Eq. (1) encourages the model to put all probability mass on $y^*$, which can easily lead to overfitting (Bunel et al., 2018; Austin et al., 2021; Cobbe et al., 2021).

Overview of our approach. Manually collecting additional reference solutions for each problem is a laborious process (Austin et al., 2021; Cobbe et al., 2021; Schuster et al., 2021). In our work, we instead explore an alternative approach in which the model self-samples additional correct (or partially-correct) solutions and learns from them during training. Fig. 1 shows an example: for the question $x$, our model was able to self-sample an alternative solution $\hat{y}$ that differs from the reference solution $y^*$ provided in the dataset. The intermediate states shown on the right show that both solutions execute to the same desired output, i.e., $\hat{z} = z^*$, as marked with solid red boxes.
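The fully-correct check $E(\hat{y}) = z^*$ and the bookkeeping of self-sampled solutions can be sketched as follows. This Python sketch rests on assumptions not stated in this section: each solution is a Python program that stores its result in a variable named `answer`, numeric answers are compared with a small tolerance, and solutions are deduplicated by exact program text. The helpers `execute`, `is_correct`, and `update_buffer` are hypothetical, and calling `exec` on model samples is only safe inside a sandbox.

```python
from typing import Any, Optional


def execute(program: str) -> Optional[Any]:
    """Hypothetical executor E: run a solution and return its `answer` variable.

    Returns None if the program raises an error or never defines `answer`.
    """
    env: dict = {}
    try:
        exec(program, env)  # unsafe outside a sandbox
    except Exception:
        return None
    return env.get("answer")


def is_correct(z_hat: Any, z_star: Any, tol: float = 1e-6) -> bool:
    """Compare answers numerically when possible, otherwise by equality."""
    try:
        return abs(float(z_hat) - float(z_star)) <= tol
    except (TypeError, ValueError):
        return z_hat == z_star


def update_buffer(buffer: set[str], samples: list[str], z_star: Any) -> list[str]:
    """Add every new sample y_hat with E(y_hat) == z* to the buffer; return the additions."""
    added = []
    for y_hat in samples:
        if y_hat in buffer:
            continue  # already known, e.g. the reference solution y*
        z_hat = execute(y_hat)
        if z_hat is not None and is_correct(z_hat, z_star):
            buffer.add(y_hat)
            added.append(y_hat)
    return added
```

Starting from a buffer that holds only the reference solution, a sample such as `answer = 3 * 4` for a question whose gold answer is 12 would be added, while a sample that crashes or returns 13 would be discarded. Under this setup, the buffered solutions are the multiple targets that a multi-target loss would train on.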
Taking this one step further, our approach can also identify partially-correct solutions among its samples. For example, the bottom left of Fig. 1 shows a sampled solution $\hat{y}'$ that is incorrect only because of errors in its last two steps. We nonetheless identify its prefix $\hat{y}'_{\leq 5}$ as partially-correct, because the intermediate state $\hat{s}'_5$ reached by this prefix matches an intermediate state $s^*$ of a known correct solution $y^*$ (marked with dashed red boxes), even though the prefix is syntactically different from $y^*$. Based on these observations and intuitions, we introduce our approach in the following sections.

¹ We choose different k for evaluating PASS@k to be consistent with previous work.
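The partial-correctness check can be sketched by executing programs one line at a time and comparing the resulting states. This section does not define how an intermediate state is represented, so the Python sketch below assumes straight-line programs with one statement per line, and represents a state as the sorted values of all variables defined so far, ignoring variable names so that a renamed but equivalent prefix still matches. Both the state representation and the helper names are hypothetical, and any `exec` on model output needs a sandbox in practice.

```python
from typing import Optional


def _lines(program: str) -> list[str]:
    """Split a straight-line program into its non-empty statements (one per line)."""
    return [ln for ln in program.strip().splitlines() if ln.strip()]


def _state(env: dict) -> tuple:
    """Assumed state: sorted values of user variables, with names ignored."""
    vals = []
    for name, v in env.items():
        if name.startswith("__"):
            continue  # skip __builtins__ inserted by exec
        if isinstance(v, (int, float)) and not isinstance(v, bool):
            v = round(float(v), 6)  # so 10 and 10.0 (or 0.1 + 0.2 and 0.3) match
        vals.append(repr(v))
    return tuple(sorted(vals))


def trace_states(program: str) -> list[tuple]:
    """Execute line by line; return the state after each line that runs successfully."""
    env: dict = {}
    states = []
    for line in _lines(program):
        try:
            exec(line, env)
        except Exception:
            break  # a later error does not invalidate the prefix before it
        states.append(_state(env))
    return states


def partially_correct_prefix(sample: str, known_correct: list[str]) -> Optional[str]:
    """Longest prefix of `sample` whose state matches an intermediate state of a known
    correct solution while its text differs from every known prefix; None if none."""
    known_states, known_prefixes = set(), set()
    for sol in known_correct:
        lines = _lines(sol)
        for i, state in enumerate(trace_states(sol), start=1):
            known_states.add(state)
            known_prefixes.add("\n".join(lines[:i]))
    lines = _lines(sample)
    best = None
    for i, state in enumerate(trace_states(sample), start=1):
        prefix = "\n".join(lines[:i])
        if state in known_states and prefix not in known_prefixes:
            best = prefix
    return best
```

As a toy analogue of the Fig. 1 example, take the reference `a = 5`, `b = a * 2`, `answer = b + 3` and the sample `x = 5`, `y = x + x`, `answer = y - 3`. The sample's first two lines reach the same values as the reference despite different names and a different expression, so its two-line prefix is returned as partially-correct, while the wrong final step is dropped.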

