DOI: 10.5555/3666122.3669306
research-article

Training chain-of-thought via latent-variable inference

Published: 30 May 2024

Abstract

Large language models (LLMs) solve problems more accurately and interpretably when instructed to work out the answer step by step using a "chain-of-thought" (CoT) prompt. One can also improve LLMs' performance on a specific task by supervised fine-tuning, i.e., by using gradient ascent on some tunable parameters to maximize the average log-likelihood of correct answers from a labeled training set. Naively combining CoT with supervised tuning requires supervision not just of the correct answers, but also of detailed rationales that lead to those answers; these rationales are expensive to produce by hand. Instead, we propose a fine-tuning strategy that tries to maximize the marginal log-likelihood of generating a correct answer using CoT prompting, approximately averaging over all possible rationales. The core challenge is sampling from the posterior over rationales conditioned on the correct answer; we address it using a simple Markov-chain Monte Carlo (MCMC) expectation-maximization (EM) algorithm inspired by the self-taught reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent contrastive divergence. This algorithm also admits a novel control-variate technique that drives the variance of our gradient estimates to zero as the model improves. Applying our technique to GSM8K and the tasks in BIG-Bench Hard, we find that this MCMC-EM fine-tuning technique typically improves the model's accuracy on held-out examples more than STaR or prompt-tuning with or without CoT.
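The core step described above, sampling rationales from the posterior conditioned on the correct answer, can be illustrated with a small toy. This is a minimal sketch under stated assumptions, not the authors' implementation. The "model" is a hypothetical categorical distribution p(z | x) = softmax(theta) over four fixed rationales, and each rationale deterministically ends in an answer. RATIONALES, ANSWERS, theta and every helper function are invented for illustration. Under that assumption p(y | z, x) is an indicator, so ∇θ log p(y | x) = E_{z ~ p(z | x, y)}[∇θ log p(z | x)]. The sketch estimates this expectation with an independence Metropolis-Hastings chain. The chain starts from a memoized correct rationale, proposes new rationales from the model itself, and keeps a proposal only if it reaches the correct answer. It then compares the estimate with the exact gradient, which is computable here only because the toy enumerates every rationale. The control-variate technique mentioned in the abstract is not reproduced.

```python
import math
import random

# Toy stand-in for an LLM (an illustrative assumption, not the paper's setup):
# for the question x = "What is 2 * 3?", the "model" is a categorical
# distribution p(z | x) = softmax(theta) over four fixed rationales, and each
# rationale deterministically ends in the answer listed in ANSWERS.
RATIONALES = ["2*3 = 2+2+2 = 6", "2*3 = 3+3 = 6", "2*3 = 2+3 = 5", "2*3 = 2*2 = 4"]
ANSWERS = ["6", "6", "5", "4"]


def softmax(theta):
    m = max(theta)  # subtract the max for numerical stability
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def marginal_log_likelihood(theta, y):
    """log p(y | x) = log sum_z p(z | x) * 1[answer(z) == y]."""
    p = softmax(theta)
    return math.log(sum(pz for pz, a in zip(p, ANSWERS) if a == y))


def exact_gradient(theta, y):
    """Gradient of log p(y | x) w.r.t. theta: posterior p(z | x, y) minus prior p(z | x)."""
    p = softmax(theta)
    mass = sum(pz for pz, a in zip(p, ANSWERS) if a == y)
    posterior = [pz / mass if a == y else 0.0 for pz, a in zip(p, ANSWERS)]
    return [q - pk for q, pk in zip(posterior, p)]


def mcmc_gradient(theta, y, z_init, steps, rng):
    """Estimate the same gradient with an independence Metropolis-Hastings chain.

    Proposals are drawn from the model, z' ~ p(z | x). The target p(z | x, y)
    is proportional to p(z | x) * 1[answer(z) == y], so the acceptance
    probability is 1 when z' ends in y and 0 otherwise. z_init must already
    end in y (e.g. a rationale memoized from an earlier pass).
    """
    assert ANSWERS[z_init] == y, "chain must start at a correct rationale"
    p = softmax(theta)
    z = z_init
    grad_sum = [0.0] * len(theta)
    for _ in range(steps):
        z_prop = rng.choices(range(len(p)), weights=p)[0]
        if ANSWERS[z_prop] == y:
            z = z_prop  # accept; otherwise keep the current (memoized) rationale
        # The gradient of log softmax(theta)[z] w.r.t. theta is onehot(z) - p.
        for k in range(len(theta)):
            grad_sum[k] += (1.0 if k == z else 0.0) - p[k]
    return [g / steps for g in grad_sum]


if __name__ == "__main__":
    theta = [0.0, 0.5, 1.0, -0.5]  # hypothetical parameters
    rng = random.Random(0)
    print("log p(y|x):", marginal_log_likelihood(theta, "6"))
    print("exact grad:", exact_gradient(theta, "6"))
    print("MCMC grad: ", mcmc_gradient(theta, "6", z_init=0, steps=50000, rng=rng))
```

With a real LLM the sum over rationales is intractable, so only the sampling-based estimate would be available. In that setting it would be passed to an optimizer to update the tunable parameters, and the memoized rationale for each training example would carry over between updates.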

References

[1]
Anil, R., Dai, A. M., Firat, O., Johnson, M., Lepikhin, D., Passos, A., Shakeri, S., Taropa, E., Bailey, P., Chen, Z., et al. PaLM 2 technical report. arXiv preprint arXiv:2305.10403, 2023.
[2]
Bishop, C. M. Pattern recognition and machine learning. Springer, 2006.
[3]
Bornschein, J. and Bengio, Y. Reweighted wake-sleep, 2015.
[4]
Burda, Y., Grosse, R., and Salakhutdinov, R. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.
[5]
Chung, H. W., Hou, L., Longpre, S., Zoph, B., Tay, Y., Fedus, W., Li, E., Wang, X., Dehghani, M., Brahma, S., et al. Scaling instruction-finetuned language models. arXiv preprint arXiv:2210.11416, 2022.
[6]
Cobbe, K., Kosaraju, V., Bavarian, M., Chen, M., Jun, H., Kaiser, L., Plappert, M., Tworek, J., Hilton, J., Nakano, R., Hesse, C., and Schulman, J. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168, 2021.
[7]
Cover, T. M. Elements of information theory. John Wiley & Sons, 1999.
[8]
Creswell, A., Shanahan, M., and Higgins, I. Selection-Inference: Exploiting large language models for interpretable logical reasoning. arXiv preprint arXiv:2205.09712, May 2022.
[9]
Dohan, D., Xu, W., Lewkowycz, A., Austin, J., Bieber, D., Lopes, R. G., Wu, Y., Michalewski, H., Saurous, R. A., Sohl-Dickstein, J., Murphy, K., and Sutton, C. Language model cascades, 2022.
[10]
Geman, S. and Geman, D. Stochastic relaxation, Gibbs distributions, and the Bayesian restoration of images. IEEE Transactions on Pattern Analysis and Machine Intelligence, 6:721-741, 1984.
[11]
Hastings, W. K. Monte Carlo sampling methods using Markov chains and their applications. Biometrika, 57(1):97-109, April 1970. ISSN 0006-3444.
[12]
Hewitt, L. B., Le, T. A., and Tenenbaum, J. B. Learning to learn generative programs with memoised wake-sleep. In Uncertainty in Artificial Intelligence, 2020.
[13]
Hol, J. D., Schön, T. B., and Gustafsson, F. On resampling algorithms for particle filters. In 2006 IEEE Nonlinear Statistical Signal Processing Workshop, pp. 79-82, 2006.
[14]
Jouppi, N., Kurian, G., Li, S., Ma, P., Nagarajan, R., Nai, L., Patil, N., Subramanian, S., Swing, A., Towles, B., et al. TPU v4: An optically reconfigurable supercomputer for machine learning with hardware support for embeddings. In Proceedings of the 50th Annual International Symposium on Computer Architecture, pp. 1-14, 2023.
[15]
Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In Bengio, Y. and LeCun, Y. (eds.), 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings, 2015. URL http://arxiv.org/abs/1412.6980.
[16]
Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.
[17]
Kojima, T., Gu, S. S., Reid, M., Matsuo, Y., and Iwasawa, Y. Large language models are Zero-Shot reasoners. arXiv preprint arXiv:2205.11916, May 2022.
[18]
Le, T. A., Kosiorek, A. R., Siddharth, N., Teh, Y. W., and Wood, F. Revisiting reweighted wake-sleep for models with stochastic control flow. In Uncertainty in Artificial Intelligence, 2019.
[19]
Lester, B., Al-Rfou, R., and Constant, N. The power of scale for parameter-efficient prompt tuning. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pp. 3045-3059, 2021.
[20]
Lewkowycz, A., Andreassen, A., Dohan, D., Dyer, E., Michalewski, H., Ramasesh, V., Slone, A., Anil, C., Schlag, I., Gutman-Solo, T., Wu, Y., Neyshabur, B., Gur-Ari, G., and Misra, V. Solving quantitative reasoning problems with language models, 2022.
[21]
Lievin, V. Deep Latent Variable Models for Natural Language Processing. PhD thesis, 2022.
[22]
Loshchilov, I. and Hutter, F. SGDR: Stochastic gradient descent with warm restarts. In International Conference on Learning Representations, 2017. URL https://openreview.net/forum?id=Skq89Scxx.
[23]
Murphy, K. P. Probabilistic Machine Learning: An introduction. MIT Press, 2022. URL probml.ai.
[24]
Murray, I. and Salakhutdinov, R. Notes on the KL-divergence between a Markov chain and its equilibrium distribution. preprint, 2008.
[25]
Naesseth, C., Lindsten, F., and Blei, D. Markovian score climbing: Variational inference with KL(p||q). Advances in Neural Information Processing Systems, 33:15499-15510, 2020.
[26]
Nielsen, S. F. The stochastic EM algorithm: Estimation and asymptotic results. Bernoulli, 6(3):457-489, June 2000.
[27]
Nye, M., Andreassen, A. J., Gur-Ari, G., Michalewski, H., Austin, J., Bieber, D., Dohan, D., Lewkowycz, A., Bosma, M., Luan, D., Sutton, C., and Odena, A. Show your work: Scratchpads for intermediate computation with language models. arXiv preprint arXiv:2112.00114, November 2021.
[28]
OpenAI. GPT-4 technical report, 2023.
[29]
Owen, A. and Zhou, Y. Safe and effective importance sampling. Journal of the American Statistical Association, 95(449):135-143, 2000.
[30]
Parisi, A., Zhao, Y., and Fiedel, N. TALM: Tool augmented language models, 2022.
[31]
Rajani, N. F., McCann, B., Xiong, C., and Socher, R. Explain yourself! Leveraging language models for commonsense reasoning. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp. 4932-4942, Florence, Italy, July 2019. Association for Computational Linguistics. URL https://aclanthology.org/P19-1487.
[32]
Roeder, G., Wu, Y., and Duvenaud, D. K. Sticking the landing: Simple, lower-variance gradient estimators for variational inference. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/e91068fff3d7fa1594dfdf3b4308433a-Paper.pdf.
[33]
Schick, T., Dwivedi-Yu, J., Dessì, R., Raileanu, R., Lomeli, M., Zettlemoyer, L., Cancedda, N., and Scialom, T. Toolformer: Language models can teach themselves to use tools, 2023.
[34]
Shinn, N., Labash, B., and Gopinath, A. Reflexion: an autonomous agent with dynamic memory and self-reflection, 2023.
[35]
Shwartz, V., West, P., Le Bras, R., Bhagavatula, C., and Choi, Y. Unsupervised common-sense question answering with self-talk. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pp. 4615-4629, Online, November 2020. Association for Computational Linguistics. URL https://aclanthology.org/2020.emnlp-main.373.
[36]
Suzgun, M., Scales, N., Schärli, N., Gehrmann, S., Tay, Y., Chung, H. W., Chowdhery, A., Le, Q. V., Chi, E. H., Zhou, D., and Wei, J. Challenging BIG-Bench tasks and whether Chain-of-Thought can solve them. arXiv preprint arXiv:2210.09261, October 2022a.
[37]
Suzgun, M., Scales, N., Schärli, N., Gehrmann, S., Tay, Y., Chung, H. W., Chowdhery, A., Le, Q. V., Chi, E. H., Zhou, D., et al. Challenging BIG-Bench tasks and whether chain-of-thought can solve them. arXiv preprint arXiv:2210.09261, 2022b.
[38]
Talmor, A., Herzig, J., Lourie, N., and Berant, J. CommonsenseQA: A question answering challenge targeting commonsense knowledge. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4149-4158, Minneapolis, Minnesota, June 2019. Association for Computational Linguistics. URL https://aclanthology.org/N19-1421.
[39]
Tieleman, T. Training restricted Boltzmann machines using approximations to the likelihood gradient. In Proceedings of the 25th International Conference on Machine Learning, pp. 1064-1071, 2008.
[40]
Tierney, L. Markov chains for exploring posterior distributions. Annals of Statistics, 22(4):1701-1728, December 1994.
[41]
Tucker, G., Mnih, A., Maddison, C. J., Lawson, J., and Sohl-Dickstein, J. REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models. Advances in Neural Information Processing Systems, 30, 2017.
[42]
Turpin, M., Michael, J., Perez, E., and Bowman, S. R. Language models don't always say what they think: Unfaithful explanations in chain-of-thought prompting. arXiv preprint arXiv:2305.04388, 2023.
[43]
Uesato, J., Kushman, N., Kumar, R., Song, F., Siegel, N., Wang, L., Creswell, A., Irving, G., and Higgins, I. Solving math word problems with process- and outcome-based feedback. arXiv preprint arXiv:2211.14275, 2022.
[44]
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. Advances in neural information processing systems, 30, 2017.
[45]
Wang, X., Wei, J., Schuurmans, D., Le, Q., Chi, E., Narang, S., Chowdhery, A., and Zhou, D. Self-Consistency improves chain of thought reasoning in language models. arXiv preprint arXiv:2203.11171, March 2022a.
[46]
Wang, X., Wei, J., Schuurmans, D., Le, Q., Chi, E., and Zhou, D. Self-consistency improves chain of thought reasoning in language models. arXiv preprint arXiv:2203.11171, 2022b.
[47]
Wei, J., Wang, X., Schuurmans, D., Bosma, M., Chi, E., Le, Q., and Zhou, D. Chain of thought prompting elicits reasoning in large language models. arXiv preprint arXiv:2201.11903, January 2022.
[48]
Yao, S., Zhao, J., Yu, D., Du, N., Shafran, I., Narasimhan, K., and Cao, Y. ReAct: Synergizing reasoning and acting in language models, 2023.
[49]
Ye, X. and Durrett, G. Explanation selection using unlabeled data for in-context learning. arXiv preprint arXiv:2302.04813, February 2023.
[50]
Zelikman, E., Wu, Y., and Goodman, N. D. STaR: Bootstrapping reasoning with reasoning. arXiv preprint arXiv:2203.14465, 2022.
[51]
Zhou, D., Schärli, N., Hou, L., Wei, J., Scales, N., Wang, X., Schuurmans, D., Bousquet, O., Le, Q., and Chi, E. Least-to-Most prompting enables complex reasoning in large language models. arXiv preprint arXiv:2205.10625, May 2022.


Published In

NIPS '23: Proceedings of the 37th International Conference on Neural Information Processing Systems
December 2023
80772 pages

Publisher

Curran Associates Inc.

Red Hook, NY, United States

Qualifiers

  • Research-article
  • Research
  • Refereed limited
