code_interpreter.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import base64
  2. import io
  3. import json
  4. import logging
  5. import os
  6. import queue
  7. import re
  8. import subprocess
  9. import sys
  10. import time
  11. import traceback
  12. import uuid
  13. import matplotlib
  14. import PIL.Image
  15. from jupyter_client import BlockingKernelClient
  16. from utils.code_utils import extract_code
  17. WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
  18. LAUNCH_KERNEL_PY = """
  19. from ipykernel import kernelapp as app
  20. app.launch_new_instance()
  21. """
  22. _KERNEL_CLIENTS = {}
  # If matplotlib cannot render CJK glyphs, this check should run before
  # Jupyter starts. In addition, the notebook itself has to execute the
  # following lines:
  # ```python
  # import matplotlib.pyplot as plt
  # plt.rcParams['font.sans-serif'] = ['SimHei']
  # plt.rcParams['axes.unicode_minus'] = False
  # ```
  def fix_matplotlib_cjk_font_issue():
      """Warn when the SimHei font file is missing from matplotlib's data.

      The font is looked up at ``fonts/ttf/simhei.ttf`` under the parent
      directory of the path returned by ``matplotlib.matplotlib_fname()``.
      Nothing is installed or modified; a missing file only produces a
      log warning.
      """
      rc_file = matplotlib.matplotlib_fname()
      data_dir = os.path.abspath(os.path.join(rc_file, os.path.pardir))
      simhei_path = os.path.join(data_dir, 'fonts', 'ttf', 'simhei.ttf')
      if os.path.exists(simhei_path):
          return
      logging.warning(f'Missing font file `{simhei_path}` for matplotlib. It may cause some error when using matplotlib.')
  def start_kernel(pid):
      """Launch an IPython kernel subprocess and return a connected client.

      A launch script and a connection-file path, both named after ``pid``,
      are placed in ``WORK_DIR``. The kernel is started with the current
      Python executable, and this function blocks until the kernel has
      written a parseable connection file and answers ``wait_for_ready()``.

      Args:
          pid: Identifier embedded in the names of the launch script and the
              connection file (callers in this module pass ``os.getpid()``).

      Returns:
          A ``BlockingKernelClient`` with started channels, ready for use.

      Raises:
          RuntimeError: If the kernel process exits before a valid
              connection file becomes available.
      """
      fix_matplotlib_cjk_font_issue()

      connection_file = os.path.join(WORK_DIR,
                                     f'kernel_connection_file_{pid}.json')
      launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py')
      # Remove leftovers first: a stale connection file would otherwise be
      # accepted by the wait loop below before the new kernel has written its
      # own.
      for f in [connection_file, launch_kernel_script]:
          if os.path.exists(f):
              logging.warning(f'{f} already exists')
              os.remove(f)

      os.makedirs(WORK_DIR, exist_ok=True)
      with open(launch_kernel_script, 'w') as fout:
          fout.write(LAUNCH_KERNEL_PY)

      kernel_process = subprocess.Popen([
          sys.executable,
          launch_kernel_script,
          '--IPKernelApp.connection_file',
          connection_file,
          '--matplotlib=inline',
          '--quiet',
      ],
                                        cwd=WORK_DIR)
      logging.info(f"INFO: kernel process's PID = {kernel_process.pid}")

      # Wait for kernel connection file to be written
      while True:
          if os.path.isfile(connection_file):
              try:
                  with open(connection_file, 'r') as fp:
                      json.load(fp)
                  break
              except json.JSONDecodeError:
                  # The file may be partially written; retry after a pause
                  # instead of spinning on it.
                  pass
          # A kernel that died during start-up will never produce the file;
          # fail loudly rather than waiting forever.
          exit_code = kernel_process.poll()
          if exit_code is not None:
              raise RuntimeError(
                  f'Kernel process exited with code {exit_code} before '
                  f'writing a valid connection file: {connection_file}')
          time.sleep(0.1)

      # Client
      kc = BlockingKernelClient(connection_file=connection_file)
      kc.load_connection_file()
      kc.start_channels()
      kc.wait_for_ready()
      return kc
  def escape_ansi(line):
      """Return ``line`` with ANSI escape sequences stripped out."""
      pattern = r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]'
      return re.sub(pattern, '', line)
  def publish_image_to_local(image_base64: str):
      """Decode a base64 image, save it as PNG in ``WORK_DIR``, return its path.

      The file name is a freshly generated UUID with a ``.png`` suffix.
      """
      target_path = os.path.join(WORK_DIR, f'{uuid.uuid4()}.png')
      png_bytes = base64.b64decode(image_base64)
      assert isinstance(png_bytes, bytes)
      decoded_image = PIL.Image.open(io.BytesIO(png_bytes))
      decoded_image.save(target_path, 'png')
      return target_path
  # Prelude executed inside the kernel: it installs the SIGALRM-based timeout
  # handler, disables input(), moves into the upload directory and imports
  # the commonly used scientific packages.
  START_CODE = (
      '\n'
      'import signal\n'
      'def _m6_code_interpreter_timeout_handler(signum, frame):\n'
      '    raise TimeoutError("M6_CODE_INTERPRETER_TIMEOUT")\n'
      'signal.signal(signal.SIGALRM, _m6_code_interpreter_timeout_handler)\n'
      'def input(*args, **kwargs):\n'
      "    raise NotImplementedError('Python input() function is disabled.')\n"
      'import os\n'
      "if 'upload_file' not in os.getcwd():\n"
      '    os.chdir("./upload_file/")\n'
      'import math\n'
      'import re\n'
      'import json\n'
      'import seaborn as sns\n'
      'sns.set_theme()\n'
      'import matplotlib\n'
      'import matplotlib.pyplot as plt\n'
      "plt.rcParams['font.sans-serif'] = ['SimHei']\n"
      "plt.rcParams['axes.unicode_minus'] = False\n"
      'import numpy as np\n'
      'import pandas as pd\n'
      'from sympy import Eq, symbols, solve\n'
  )
  def code_interpreter(action_input_list: list, timeout=30, clear=False):
      """Assemble code from ``action_input_list`` and run it in the kernel.

      Each entry is passed through ``extract_code`` and the results are
      concatenated, one per line. Two adjustments are applied before
      execution:

      * after every line that starts with ``sns.set_theme(`` the SimHei font
        settings are re-applied (presumably because ``set_theme`` resets
        them -- confirm against seaborn's behavior);
      * when the code contains ``def solution()``, a call to ``solution()``
        is appended.

      Args:
          action_input_list: Raw inputs handed to ``extract_code``.
          timeout: Seconds forwarded to ``_code_interpreter``.
          clear: Forwarded to ``_code_interpreter``.

      Returns:
          Whatever ``_code_interpreter`` returns for the assembled code.
      """
      code = ''.join(
          extract_code(action_input) + '\n'
          for action_input in action_input_list)

      patched_lines = []
      for source_line in code.split('\n'):
          patched_lines.append(source_line)
          if not source_line.startswith('sns.set_theme('):
              continue
          patched_lines.extend([
              'plt.rcParams["font.sans-serif"] = ["SimHei"]',
              'plt.rcParams["axes.unicode_minus"] = False',
          ])
      fixed_code = '\n'.join(patched_lines)

      if 'def solution()' in fixed_code:
          fixed_code = fixed_code + '\nsolution()'
      return _code_interpreter(fixed_code, timeout, clear)
  def _code_interpreter(code: str, timeout, clear=False):
      """Run ``code`` in this process's kernel and return its formatted output.

      A kernel is started on first use (one per ``os.getpid()``) and primed
      with ``START_CODE``. Output messages are read from the IOPub channel
      until the kernel reports ``idle``; text is wrapped in fenced blocks
      labelled with the message type and images are saved locally and
      referenced with Markdown image syntax.

      Args:
          code: Source to execute; blank input returns ``''`` immediately.
          timeout: If truthy, ``signal.alarm(timeout)`` is prepended to the
              code and the alarm is cancelled again afterwards.
          clear: If true, the namespace is reset and ``START_CODE`` re-run
              before the code.

      Returns:
          The collected output as a single string without leading newlines.
      """
      if not code.strip():
          return ''

      if timeout:
          code = f'signal.alarm({timeout})\n{code}'
      if clear:
          code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE + code

      pid = os.getpid()
      if pid not in _KERNEL_CLIENTS:
          _KERNEL_CLIENTS[pid] = start_kernel(pid)
          # Prime the fresh kernel; the client is already cached, so this
          # nested call does not start another kernel.
          _code_interpreter(START_CODE, timeout=None)
      kc = _KERNEL_CLIENTS[pid]
      kc.wait_for_ready()
      kc.execute(code)

      chunks = []
      image_idx = 0
      finished = False
      while not finished:
          text = ''
          image = ''
          msg_type = 'error'
          try:
              # NOTE(review): no timeout is passed here, so the queue.Empty
              # handler below may never trigger -- confirm against
              # jupyter_client's get_iopub_msg semantics.
              msg = kc.get_iopub_msg()
              msg_type = msg['msg_type']
              if msg_type == 'status':
                  finished = msg['content'].get('execution_state') == 'idle'
              elif msg_type in ('execute_result', 'display_data'):
                  data = msg['content']['data']
                  has_png = 'image/png' in data
                  # execute_result always reports its text; display_data only
                  # falls back to text when it carries no image.
                  if msg_type == 'execute_result' or not has_png:
                      text = data.get('text/plain', '')
                  if has_png:
                      image_url = publish_image_to_local(data['image/png'])
                      image_idx += 1
                      image = '![fig-%03d](%s)' % (image_idx, image_url)
              elif msg_type == 'stream':
                  content = msg['content']
                  msg_type = content['name']  # stdout, stderr
                  text = content['text']
              elif msg_type == 'error':
                  text = escape_ansi('\n'.join(msg['content']['traceback']))
                  if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                      text = f'Timeout. No response after {timeout} seconds.'
          except queue.Empty:
              text = f'Timeout. No response after {timeout} seconds.'
              finished = True
          except Exception:
              text = 'The code interpreter encountered an unexpected error.'
              logging.warning(traceback.format_exc())
              finished = True
          if text:
              chunks.append(f'\n\n{msg_type}:\n\n```\n{text}\n```')
          if image:
              chunks.append(f'\n\n{image}')

      result = ''.join(chunks).lstrip('\n')
      if timeout:
          # Cancel the pending alarm so it cannot fire during a later call.
          _code_interpreter('signal.alarm(0)', timeout=None)
      return result
  def get_multiline_input(hint):
      """Print ``hint``, read lines from stdin until EOF and return them.

      The lines are joined with newlines; an empty string is returned when
      nothing was entered.
      """
      print(hint)
      print('// Press ENTER to make a new line. Press CTRL-D to end input.')
      collected = []
      try:
          while True:
              collected.append(input())
      except EOFError:  # CTRL-D
          pass
      print('// Input received.')
      return '\n'.join(collected)
  if __name__ == '__main__':
      # Minimal interactive loop: read a snippet, run it, print the output.
      while True:
          user_code = get_multiline_input('Enter python code:')
          print(code_interpreter([user_code]))