inference_and_execute.py

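"""Benchmark driver for code-interpreter evaluation.

Runs ReAct-style inference with a code interpreter plugin over the GSM8K and
code-interpreter test sets, caches generations to a JSONL file, and reports
code executability / correctness metrics per task.
"""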
import argparse
import json
import logging
import os
from parser import ReActParser

import prettytable
import tqdm
from code_interpreter import code_interpreter
from config import (get_model, get_react_parser, get_react_prompt,
                    model_path_map)
from datasets import load_dataset
from metrics.code_execution import eval_code_execution_rate
from metrics.gsm8k import eval_gsm8k_acc, is_correct
from metrics.visualization import eval_visualization_acc
from utils.code_utils import replace_upload_fname
from utils.data_utils import load_jsonl

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

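# Working directory for the code interpreter; stage the clean copies of the
# benchmark's uploaded input files both there and in the current directory.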
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
os.makedirs(WORK_DIR, exist_ok=True)
os.system(f'cp -r upload_file_clean {WORK_DIR}/upload_file')
os.system('cp -r upload_file_clean ./upload_file')

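# Aggregated benchmark results; filled in by eval_metrics() and printed as
# tables by gather_eval_result().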
global_eval_result = {
    'code_executability': {
        'math': None,
        'visualization': None,
        'general': None,
    },
    'code_correctness': {
        'math': None,
        'visualization-hard': None,
        'visualization-easy': None,
    }
}


def llm_with_plugin(args, query, item=None, exec_limit=3):
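    """ReAct-style inference loop with code execution.

    Alternates between querying the model and running the code it emits
    through the code interpreter, feeding each observation back into the
    prompt, until the model stops calling the plugin or `exec_limit`
    executions have been used. Returns the accumulated generation.
    """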
    exec_count = 0

    # Build ReAct prompt
    upload_fname_list = item['input_file_path'] if item and 'input_file_path' in item else []
    lang = item['lang'] if item and 'lang' in item else 'en'
    react_prompt_obj = get_react_prompt(args.model, query, lang, upload_fname_list)
    planning_prompt = react_prompt_obj.build_prompt()

    # If the query already carries a ReAct turn (marked by '<|im_start|>'),
    # execute its code first; this uses one execution, so extend the budget.
    if '<|im_start|>' in query:
        _, prepend_code, __ = ReActParser().parse_latest_plugin_call(query)
        prepend_code = replace_upload_fname(prepend_code, upload_fname_list)
        call_plugin(_, [prepend_code], clear=(exec_count == 0))
        exec_count += 1
        exec_limit += 1

    # Inference and execute
    text = ''
    while exec_count < exec_limit:
        stop_words_list = react_prompt_obj.get_stop_words_list()
        output = text_completion(args.llm, planning_prompt + text, stop_words=stop_words_list)

        if args.gen_only:
            text += output
            break

        react_parser = get_react_parser(args.model)
        action, action_input, output = react_parser.parse_latest_plugin_call(output)
        if action:
            action_input = replace_upload_fname(action_input, upload_fname_list)
            observation = call_plugin(action, [action_input], clear=(exec_count == 0))
            output += react_prompt_obj.build_observation(observation)
            text += output
            exec_count += 1
            if 'error:' in observation or 'Traceback' in observation:
                break
        else:
            text += output
            break
    return text


def text_completion(llm, input_text, stop_words=[]):
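    """Generate a continuation of `input_text`, logging both the prompt and the output."""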
    logging.info('Generating'.center(60, '='))
    logging.info('Input'.center(60, '-'))
    logging.info(input_text)
    output = llm.generate(input_text, stop_words)
    logging.info('Output'.center(60, '-'))
    logging.info(output)
    return output


def call_plugin(plugin_name, plugin_args_list, clear=False):
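    """Route a plugin call to the code interpreter; `plugin_name` is accepted but not checked."""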
    # Relax constraints on plugin name.
    logging.info('Call code interpreter'.center(60, '='))
    obs = code_interpreter(plugin_args_list, clear=clear)
    logging.info(obs)
    return obs


def process_code_interpreter(item, writer):
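    """Run one code-interpreter benchmark item and append the generation to the output JSONL."""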
    query = item['query']
    exec_limit = 3 if 'visualization' in item['tags'] else 1
    response = llm_with_plugin(args=args, query=query, item=item, exec_limit=exec_limit)
    item['gen'] = response

    writer.write(json.dumps(item, ensure_ascii=False) + '\n')
    writer.flush()


def process_gsm8k(doc, writer):
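    """Run one GSM8K question, score it with is_correct(), and append the result to the output JSONL."""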
    context = doc['question']
    completion = llm_with_plugin(args=args, query=context)
    acc = is_correct(completion, doc['answer'])
    doc['completion'] = completion
    doc['acc'] = acc

    writer.write(json.dumps(doc, ensure_ascii=False) + '\n')
    writer.flush()


def sequential_processing(args, data_list, process_func, writer):
    for item in tqdm.tqdm(data_list):
        process_func(item, writer)


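# Per-task processors; tasks without an entry (e.g. 'math', 'general') fall
# back to process_code_interpreter in main().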
process_func_map = {
    'gsm8k': process_gsm8k,
    'visualization': process_code_interpreter
}


def gather_eval_result(model_name):
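    """Print the collected metrics, one PrettyTable per metric group."""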
    for metric in global_eval_result:
        logging.info(metric)
        table = prettytable.PrettyTable()
        table.field_names = ['model'] + list(global_eval_result[metric].keys())
        row_data = [model_name]
        for item in global_eval_result[metric].values():
            item = str(item) if not item else str(round(item, 2))
            row_data.append(item)
        table.add_row(row_data)
        logging.info('\n' + str(table))


def eval_metrics(args, test_set, full_output_fname):
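    """Score the completed inference file and merge the results into global_eval_result."""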
    # metrics
    assert os.path.exists(full_output_fname), f'Not Found File {full_output_fname}.'
    inference_res = load_jsonl(full_output_fname)
    assert len(inference_res) == len(test_set), f'There are still {len(test_set)-len(inference_res)} cases left.'

    abs_output_fname = os.path.join(os.path.dirname(os.path.abspath(__file__)), full_output_fname)
    if args.task == 'gsm8k':
        math_code_correctness = eval_gsm8k_acc(abs_output_fname)
        global_eval_result['code_correctness'].update(math_code_correctness)
    else:
        code_executability = eval_code_execution_rate(abs_output_fname, args.task, args.model)
        global_eval_result['code_executability'].update(code_executability)
        if args.task in ['all_ci', 'visualization'] and not args.eval_code_exec_only:
            visualization_code_correctness = eval_visualization_acc(abs_output_fname, args.model)
            global_eval_result['code_correctness'].update(visualization_code_correctness)


def main(args):
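    """Build the test set for the current task, run any remaining cases, then evaluate."""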
    current_dir = os.getcwd()
    os.makedirs(args.output_path, exist_ok=True)
    full_output_fname = os.path.join(args.output_path,
                                     (args.output_fname or f'{args.task}_{args.model}_res.jsonl'))

    if not os.path.exists(full_output_fname):
        with open(full_output_fname, 'w'):
            logging.info(f'Create file {full_output_fname} done.')

    # build data
    if args.task == 'gsm8k':
        dataset = load_dataset('gsm8k', 'main')
        test_set = dataset['test']
    else:
        eval_data_path = os.path.join(args.input_path, args.input_fname)
        test_set = [item for item in load_jsonl(eval_data_path) if args.task in item['tags']]
    logging.info(f'Test set: {len(test_set)}')

    if args.eval_only:
        eval_metrics(args, test_set, full_output_fname)
    else:
        key = 'question' if args.task == 'gsm8k' else 'query'
        cache_question = [item[key] for item in load_jsonl(full_output_fname)] if not args.force else []
        data_list = [item for item in test_set if item[key] not in cache_question]
        logging.info(f'Left cases: {len(data_list)}')

        # inference
        writer_mode = 'w' if args.force else 'a'
        f_output = open(full_output_fname, writer_mode, encoding='utf-8')
        process_func = process_func_map.get(args.task, process_code_interpreter)
        sequential_processing(args, data_list, process_func, f_output)
        f_output.close()

        # evaluate
        if not args.gen_exec_only:
            eval_metrics(args, test_set, full_output_fname)

    os.chdir(current_dir)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='qwen-14b-chat', choices=list(model_path_map.keys()))
    parser.add_argument('--task', type=str, default='all', choices=['all', 'gsm8k', 'all_ci', 'visualization', 'math', 'general'])
    parser.add_argument('--output-path', type=str, default='output_data')
    parser.add_argument('--input-path', type=str, default='eval_data')
    parser.add_argument('-o', '--output-fname', type=str, default='')
    parser.add_argument('-i', '--input-fname', type=str, default='eval_code_interpreter_v1.jsonl')
    parser.add_argument('-f', '--force', action='store_true', default=False)
    parser.add_argument('--eval-only', action='store_true', default=False)
    parser.add_argument('--eval-code-exec-only', action='store_true', default=False)
    parser.add_argument('--gen-exec-only', action='store_true', default=False)
    parser.add_argument('--gen-only', action='store_true', default=False)
    args = parser.parse_args()
    return args
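

# Example invocations (assuming eval_data/ and upload_file_clean/ are present
# in the working directory; the model and file names below are illustrative):
#   python inference_and_execute.py --task gsm8k --model qwen-14b-chat
#   python inference_and_execute.py --task visualization --model qwen-14b-chat --eval-only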
if __name__ == '__main__':
    args = parse_args()

    if not args.eval_only:
        args.llm = get_model(args.model)
        logging.info(f'Init {args.model} done.')

    if args.task == 'all':
        for key in ['all_ci', 'gsm8k']:
            args.task = key
            main(args)
    else:
        main(args)
    gather_eval_result(args.model)