code_execution.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import logging
  2. import os
  3. import func_timeout
  4. from config import get_react_parser
  5. from func_timeout import func_set_timeout
  6. from utils.code_utils import extract_code, replace_upload_fname
  7. from utils.data_utils import load_jsonl, save_jsonl
# Prelude source that postprocess_code prepends to every generated snippet
# before it is executed. It changes into ./upload_file/ unless the current
# working directory path already contains 'upload_file', switches matplotlib
# to interactive mode (plt.ion()) and imports the libraries listed below.
pre_load = """
import os
if 'upload_file' not in os.getcwd():
    os.chdir("./upload_file/")
import seaborn as sns
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.ion()
import numpy as np
import pandas as pd
from sympy import Eq, symbols, solve
import re
import json
import math
"""
  24. tags_config = {
  25. 'visualization': {
  26. 'timelimit': True,
  27. 'extract_first_code': True,
  28. },
  29. 'math': {
  30. 'timelimit': True,
  31. 'extract_first_code': False,
  32. },
  33. 'general': {
  34. 'timelimit': False,
  35. 'extract_first_code': True,
  36. }
  37. }
  38. code_executability = {
  39. 'math': None,
  40. 'visualization': None,
  41. 'general': None
  42. }
@func_set_timeout(10)
def exec_limit_time(text):
    """Execute the source string ``text`` under the ``func_set_timeout(10)`` decorator.

    The dict returned by ``locals()`` (which holds ``text``) is passed to
    ``exec`` as the globals of the executed code. The caller in this file
    catches ``func_timeout.exceptions.FunctionTimedOut`` around this call.
    """
    # NOTE(review): presumably the decorator argument is a limit in seconds - confirm against func_timeout.
    # NOTE(review): exec runs generated code unsandboxed; confirm the harness is only used in a disposable environment.
    exec(text, locals())
  46. def exec_code(text, timelimit=False):
  47. if timelimit:
  48. exec_limit_time(text)
  49. else:
  50. exec(text, locals())
  51. def postprocess_code(gen_code, line):
  52. if '<|im_start|>' in line['query']:
  53. first_action_code = get_action_input_code(line['query'])
  54. gen_code = first_action_code + gen_code
  55. upload_fname_list = line['input_file_path'] if line and 'input_file_path' in line else []
  56. gen_code = replace_upload_fname(gen_code, upload_fname_list)
  57. if 'def solution()' in gen_code:
  58. gen_code += '\nsolution()\n'
  59. if 'plt.show()' in gen_code:
  60. gen_code += "\nplt.pause(1)\nplt.close('all')\n"
  61. if 'sns.' in gen_code and 'plot' in gen_code:
  62. gen_code += "\nplt.close('all')\n"
  63. gen_code = pre_load + gen_code
  64. return gen_code
  65. def get_action_input_code(text, model_name='qwen-14b-chat', extract_first_code=False):
  66. action_input_list = []
  67. tmp = text
  68. react_parser = get_react_parser(model_name)
  69. while True:
  70. action_input = react_parser.get_first_action_input(tmp)
  71. if not action_input:
  72. break
  73. action_input_list.append(action_input)
  74. tmp = tmp.split(action_input)[1]
  75. if not tmp or extract_first_code:
  76. break
  77. code = ''
  78. for action_input in action_input_list:
  79. code = code + '# concat\n' + extract_code(action_input) + '\n'
  80. return code
  81. def eval_code_execution_rate(output_fname, tag='all_ci', model_name='qwen-14b-chat', timelimit=False, extract_first_code=False):
  82. data_list = load_jsonl(output_fname)
  83. pip_package = []
  84. for line_id, line in enumerate(data_list):
  85. line['idx'] = line_id
  86. tags_list = line['tags'].split(',')
  87. if tag not in tags_list:
  88. continue
  89. # update args
  90. for cur_tag in tags_list:
  91. if cur_tag != 'all_ci':
  92. timelimit = tags_config[cur_tag]['timelimit']
  93. extract_first_code = tags_config[cur_tag]['extract_first_code']
  94. line['executable_code'] = False
  95. line['missing_code'] = False
  96. line['code_error_info'] = ''
  97. # get Action Input code from response
  98. gen_code = get_action_input_code(line['gen'], model_name=model_name, extract_first_code=extract_first_code)
  99. if not gen_code:
  100. line['missing_code'] = True
  101. line['code'] = ''
  102. line['code_error_info'] = 'missing code'
  103. continue
  104. line['code'] = gen_code
  105. gen_code = postprocess_code(gen_code, line)
  106. while True:
  107. try:
  108. exec_code(gen_code, timelimit=timelimit)
  109. line['executable_code'] = True
  110. break
  111. except func_timeout.exceptions.FunctionTimedOut as ex:
  112. line['code_error_info'] = str(ex)
  113. break
  114. except (ImportError, ModuleNotFoundError) as ex:
  115. try:
  116. packege = str(ex).split("'")[1].strip()
  117. except Exception:
  118. packege = ''
  119. if packege and packege not in pip_package: # install package
  120. pip_package.append(packege)
  121. os.system('pip install '+packege)
  122. logging.info(f'Automatic installation: {packege}')
  123. else:
  124. line['code_error_info'] = str(ex)
  125. break
  126. except Exception as ex:
  127. line['code_error_info'] = str(ex)
  128. break
  129. # double check
  130. observation = get_react_parser(model_name).get_first_observation(line['gen'])
  131. if line['executable_code'] and ('error:' in observation):
  132. logging.warning('The code executes correctly, but it has an error in IPython!')
  133. logging.warning(f'Code:\n{gen_code}')
  134. logging.warning(f'IPython error info:\n{observation}')
  135. logging.info('='*60)
  136. elif not line['executable_code'] and not ('error:' in observation):
  137. logging.warning('The code has an execution error, but it runs correctly in IPython!')
  138. logging.warning(f'Code:\n{gen_code}')
  139. logging.warning(f"Exec error info:\n{line['code_error_info']}")
  140. logging.warning(f'IPython observation:\n{observation}')
  141. logging.info('='*60)
  142. # save error data
  143. error_data_list = [item for item in data_list if not item['executable_code'] or item['missing_code']]
  144. error_data_output_fname = os.path.splitext(output_fname)[0] + '_exec_error.jsonl'
  145. save_jsonl(error_data_list, error_data_output_fname)
  146. log_result(data_list)
  147. return code_executability
  148. def log_result(data_list, verbose=True):
  149. if verbose:
  150. logging.info('*'*60)
  151. logging.info('{:^60}'.format('Detail'))
  152. logging.info('*'*60)
  153. for line_id, line in enumerate(data_list):
  154. logging.info(f'Question {line_id}'.center(60, '='))
  155. logging.info(line['query'])
  156. logging.info(f'Generated {line_id}'.center(60, '-'))
  157. logging.info('\n' + line['gen'])
  158. logging.info(f'Code {line_id}'.center(60, '-'))
  159. logging.info('\n' + line['code'])
  160. logging.info(f'Exec Result {line_id}'.center(60, '-'))
  161. prefix_info = 'Exec Success' if line['executable_code'] else 'Exec Error: '
  162. exec_info = prefix_info + line['code_error_info']
  163. logging.info(exec_info)
  164. logging.info('='*60)
  165. logging.info('{:^60}'.format('Code Execuation Rate'))
  166. logging.info('='*60)
  167. involved_tags = []
  168. for line in data_list:
  169. involved_tags += line['tags'].split(',')
  170. involved_tags = list(set(involved_tags))
  171. for key in involved_tags:
  172. logging.info(f'task: {key}'.center(60, '='))
  173. key_item_list = [item for item in data_list if key in item['tags']]
  174. all_count = len(key_item_list)
  175. missing_code_count = len([item for item in key_item_list if item['missing_code']])
  176. executable_code_count = len([item for item in key_item_list if item['executable_code']])
  177. logging.info(f'All Test: {all_count}')
  178. logging.info(f'Missing Code: {missing_code_count}')
  179. logging.info(f'Predict Exec Success: {executable_code_count}')
  180. logging.info('Codes available && Execution Rate: {:.2f}'.format(executable_code_count/(all_count-missing_code_count)*100))
  181. logging.info('Execution Rate: {:.2f}'.format(executable_code_count/all_count*100))
  182. logging.info('Non-executable rate: {:.2f}'.format((all_count-missing_code_count-executable_code_count)/all_count*100))
  183. logging.info('Missing code rate: {:.2f}'.format(missing_code_count/all_count*100))
  184. if key != 'all_ci':
  185. code_executability[key] = executable_code_count/all_count*100
  186. if verbose:
  187. logging.info('Error List: ')
  188. error_list = [(item['idx'], item['code_error_info']) for item in key_item_list if item['code_error_info']]
  189. error_list.sort(key=lambda x: x[1])
  190. for x in error_list:
  191. logging.info(x)