gsm8k.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import logging
  2. import os
  3. import re
  4. import numpy as np
  5. from utils.data_utils import load_jsonl, save_jsonl
  6. INVALID_ANS = '[invalid]'
  7. def extract_answer(completion):
  8. def _get_last_digit(s):
  9. _PAT_LAST_DIGIT = re.compile(r'(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))')
  10. match = list(_PAT_LAST_DIGIT.finditer(s))
  11. if match:
  12. last_digit = match[-1].group().replace(',', '').replace('+', '')
  13. else:
  14. last_digit = None
  15. logging.warning(f'No digits found in {s!r}')
  16. return last_digit
  17. job_gen = completion.strip('.').replace('\n', '\\n')
  18. last_digit = _get_last_digit(job_gen)
  19. if last_digit:
  20. return eval(last_digit)
  21. else:
  22. return INVALID_ANS
  23. def is_correct(completion, answer):
  24. gold = extract_answer(answer)
  25. assert gold != INVALID_ANS, 'No ground truth answer found in the document.'
  26. return extract_answer(completion) == gold
  27. def eval_gsm8k_acc(output_fname):
  28. data_list = load_jsonl(output_fname)
  29. acc_res = [item['acc'] for item in data_list]
  30. logging.info('='*60)
  31. logging.info('{:^60}'.format('Math Acc.'))
  32. logging.info('='*60)
  33. logging.info('Total num={:.2f}'.format(len(acc_res)))
  34. logging.info('Right num={:.2f}'.format(np.sum(acc_res)))
  35. logging.info('Zero-shot Acc={:.2f}'.format(np.mean(acc_res)*100))
  36. error_data_list = [item for item in data_list if not item['acc']]
  37. error_data_output_fname = os.path.splitext(output_fname)[0] + '_gsm8k_error.jsonl'
  38. save_jsonl(error_data_list, error_data_output_fname)
  39. return {'math': np.mean(acc_res)*100}