code_interpreter.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import asyncio
  2. import atexit
  3. import base64
  4. import glob
  5. import io
  6. import json
  7. import os
  8. import queue
  9. import re
  10. import shutil
  11. import signal
  12. import stat
  13. import subprocess
  14. import sys
  15. import time
  16. import uuid
  17. from pathlib import Path
  18. from typing import Dict, List, Optional, Union
  19. import json5
  20. from qwen_agent.log import logger
  21. from qwen_agent.tools.base import BaseToolWithFileAccess, register_tool
  22. from qwen_agent.utils.utils import append_signal_handler, extract_code, has_chinese_chars, print_traceback
  23. LAUNCH_KERNEL_PY = """
  24. from ipykernel import kernelapp as app
  25. app.launch_new_instance()
  26. """
  27. INIT_CODE_FILE = str(Path(__file__).absolute().parent / 'resource' / 'code_interpreter_init_kernel.py')
  28. ALIB_FONT_FILE = str(Path(__file__).absolute().parent / 'resource' / 'AlibabaPuHuiTi-3-45-Light.ttf')
  29. _KERNEL_CLIENTS: dict = {}
  30. _MISC_SUBPROCESSES: Dict[str, subprocess.Popen] = {}
  31. def _kill_kernels_and_subprocesses(_sig_num=None, _frame=None):
  32. for v in _KERNEL_CLIENTS.values():
  33. v.shutdown()
  34. for k in list(_KERNEL_CLIENTS.keys()):
  35. del _KERNEL_CLIENTS[k]
  36. for v in _MISC_SUBPROCESSES.values():
  37. v.terminate()
  38. for k in list(_MISC_SUBPROCESSES.keys()):
  39. del _MISC_SUBPROCESSES[k]
  40. # Make sure all subprocesses are terminated even if killed abnormally:
  41. atexit.register(_kill_kernels_and_subprocesses)
  42. append_signal_handler(signal.SIGTERM, _kill_kernels_and_subprocesses)
  43. append_signal_handler(signal.SIGINT, _kill_kernels_and_subprocesses)
  44. @register_tool('code_interpreter')
  45. class CodeInterpreter(BaseToolWithFileAccess):
  46. description = 'Python代码沙盒,可用于执行Python代码。'
  47. parameters = [{'name': 'code', 'type': 'string', 'description': '待执行的代码', 'required': True}]
  48. def __init__(self, cfg: Optional[Dict] = None):
  49. super().__init__(cfg)
  50. self.work_dir: str = os.getenv('M6_CODE_INTERPRETER_WORK_DIR', self.work_dir)
  51. self.work_dir: str = self.cfg.get('work_dir', self.work_dir)
  52. self.instance_id: str = str(uuid.uuid4())
  53. _check_deps_for_code_interpreter()
  54. @property
  55. def args_format(self) -> str:
  56. fmt = self.cfg.get('args_format')
  57. if fmt is None:
  58. if has_chinese_chars([self.name_for_human, self.name, self.description, self.parameters]):
  59. fmt = '此工具的输入应为Markdown代码块。'
  60. else:
  61. fmt = 'Enclose the code within triple backticks (`) at the beginning and end of the code.'
  62. return fmt
  63. def call(self, params: Union[str, dict], files: List[str] = None, timeout: Optional[int] = 30, **kwargs) -> str:
  64. super().call(params=params, files=files) # copy remote files to work_dir
  65. try:
  66. params = json5.loads(params)
  67. code = params['code']
  68. except Exception:
  69. code = extract_code(params)
  70. if not code.strip():
  71. return ''
  72. kernel_id: str = f'{self.instance_id}_{os.getpid()}'
  73. if kernel_id in _KERNEL_CLIENTS:
  74. kc = _KERNEL_CLIENTS[kernel_id]
  75. else:
  76. _fix_matplotlib_cjk_font_issue()
  77. self._fix_secure_write_for_code_interpreter()
  78. kc, subproc = self._start_kernel(kernel_id)
  79. with open(INIT_CODE_FILE) as fin:
  80. start_code = fin.read()
  81. start_code = start_code.replace('{{M6_FONT_PATH}}', repr(ALIB_FONT_FILE)[1:-1])
  82. start_code += '\n%xmode Minimal'
  83. logger.info(self._execute_code(kc, start_code))
  84. _KERNEL_CLIENTS[kernel_id] = kc
  85. _MISC_SUBPROCESSES[kernel_id] = subproc
  86. if timeout:
  87. code = f'_M6CountdownTimer.start({timeout})\n{code}'
  88. fixed_code = []
  89. for line in code.split('\n'):
  90. fixed_code.append(line)
  91. if line.startswith('sns.set_theme('):
  92. fixed_code.append('plt.rcParams["font.family"] = _m6_font_prop.get_name()')
  93. fixed_code = '\n'.join(fixed_code)
  94. fixed_code += '\n\n' # Prevent code not executing in notebook due to no line breaks at the end
  95. result = self._execute_code(kc, fixed_code)
  96. if timeout:
  97. self._execute_code(kc, '_M6CountdownTimer.cancel()')
  98. return result if result.strip() else 'Finished execution.'
  99. def __del__(self):
  100. # Recycle the jupyter subprocess:
  101. k: str = f'{self.instance_id}_{os.getpid()}'
  102. if k in _KERNEL_CLIENTS:
  103. _KERNEL_CLIENTS[k].shutdown()
  104. del _KERNEL_CLIENTS[k]
  105. if k in _MISC_SUBPROCESSES:
  106. _MISC_SUBPROCESSES[k].terminate()
  107. del _MISC_SUBPROCESSES[k]
  108. def _fix_secure_write_for_code_interpreter(self):
  109. if 'linux' in sys.platform.lower():
  110. os.makedirs(self.work_dir, exist_ok=True)
  111. fname = os.path.join(self.work_dir, f'test_file_permission_{os.getpid()}.txt')
  112. if os.path.exists(fname):
  113. os.remove(fname)
  114. with os.fdopen(os.open(fname, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o0600), 'w') as f:
  115. f.write('test')
  116. file_mode = stat.S_IMODE(os.stat(fname).st_mode) & 0o6677
  117. if file_mode != 0o0600:
  118. os.environ['JUPYTER_ALLOW_INSECURE_WRITES'] = '1'
  119. if os.path.exists(fname):
  120. os.remove(fname)
  121. def _start_kernel(self, kernel_id: str):
  122. connection_file = os.path.join(self.work_dir, f'kernel_connection_file_{kernel_id}.json')
  123. launch_kernel_script = os.path.join(self.work_dir, f'launch_kernel_{kernel_id}.py')
  124. for f in [connection_file, launch_kernel_script]:
  125. if os.path.exists(f):
  126. logger.info(f'WARNING: {f} already exists')
  127. os.remove(f)
  128. os.makedirs(self.work_dir, exist_ok=True)
  129. with open(launch_kernel_script, 'w') as fout:
  130. fout.write(LAUNCH_KERNEL_PY)
  131. kernel_process = subprocess.Popen(
  132. [
  133. sys.executable,
  134. os.path.abspath(launch_kernel_script),
  135. '--IPKernelApp.connection_file',
  136. os.path.abspath(connection_file),
  137. '--matplotlib=inline',
  138. '--quiet',
  139. ],
  140. cwd=os.path.abspath(self.work_dir),
  141. )
  142. logger.info(f"INFO: kernel process's PID = {kernel_process.pid}")
  143. # Wait for kernel connection file to be written
  144. while True:
  145. if not os.path.isfile(connection_file):
  146. time.sleep(0.1)
  147. else:
  148. # Keep looping if JSON parsing fails, file may be partially written
  149. try:
  150. with open(connection_file, 'r') as fp:
  151. json.load(fp)
  152. break
  153. except json.JSONDecodeError:
  154. pass
  155. # Client
  156. from jupyter_client import BlockingKernelClient
  157. kc = BlockingKernelClient(connection_file=connection_file)
  158. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  159. kc.load_connection_file()
  160. kc.start_channels()
  161. kc.wait_for_ready()
  162. return kc, kernel_process
  163. def _execute_code(self, kc, code: str) -> str:
  164. kc.wait_for_ready()
  165. kc.execute(code)
  166. result = ''
  167. image_idx = 0
  168. while True:
  169. text = ''
  170. image = ''
  171. finished = False
  172. msg_type = 'error'
  173. try:
  174. msg = kc.get_iopub_msg()
  175. msg_type = msg['msg_type']
  176. if msg_type == 'status':
  177. if msg['content'].get('execution_state') == 'idle':
  178. finished = True
  179. elif msg_type == 'execute_result':
  180. text = msg['content']['data'].get('text/plain', '')
  181. if 'image/png' in msg['content']['data']:
  182. image_b64 = msg['content']['data']['image/png']
  183. image_url = self._serve_image(image_b64)
  184. image_idx += 1
  185. image = '![fig-%03d](%s)' % (image_idx, image_url)
  186. elif msg_type == 'display_data':
  187. if 'image/png' in msg['content']['data']:
  188. image_b64 = msg['content']['data']['image/png']
  189. image_url = self._serve_image(image_b64)
  190. image_idx += 1
  191. image = '![fig-%03d](%s)' % (image_idx, image_url)
  192. else:
  193. text = msg['content']['data'].get('text/plain', '')
  194. elif msg_type == 'stream':
  195. msg_type = msg['content']['name'] # stdout, stderr
  196. text = msg['content']['text']
  197. elif msg_type == 'error':
  198. text = _escape_ansi('\n'.join(msg['content']['traceback']))
  199. if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
  200. text = 'Timeout: Code execution exceeded the time limit.'
  201. except queue.Empty:
  202. text = 'Timeout: Code execution exceeded the time limit.'
  203. finished = True
  204. except Exception:
  205. text = 'The code interpreter encountered an unexpected error.'
  206. print_traceback()
  207. finished = True
  208. if text:
  209. result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
  210. if image:
  211. result += f'\n\n{image}'
  212. if finished:
  213. break
  214. result = result.lstrip('\n')
  215. return result
  216. def _serve_image(self, image_base64: str) -> str:
  217. import PIL.Image
  218. image_file = f'{uuid.uuid4()}.png'
  219. local_image_file = os.path.join(self.work_dir, image_file)
  220. png_bytes = base64.b64decode(image_base64)
  221. assert isinstance(png_bytes, bytes)
  222. bytes_io = io.BytesIO(png_bytes)
  223. PIL.Image.open(bytes_io).save(local_image_file, 'png')
  224. image_server_url = os.getenv('M6_CODE_INTERPRETER_STATIC_URL', '')
  225. if image_server_url:
  226. return f'{image_server_url}/{image_file}'
  227. return local_image_file
  228. def _check_deps_for_code_interpreter():
  229. try:
  230. import matplotlib # noqa
  231. import matplotlib.pyplot as plt # noqa
  232. import numpy as np # noqa
  233. import pandas as pd # noqa
  234. import PIL.Image # noqa
  235. import seaborn as sns # noqa
  236. from jupyter_client import BlockingKernelClient # noqa
  237. from sympy import Eq, solve, symbols # noqa
  238. except ImportError as e:
  239. raise ImportError(
  240. 'The dependencies for Code Interpreter support are not installed. '
  241. 'Please install the required dependencies by running: pip install qwen-agent[code_interpreter]') from e
  242. def _fix_matplotlib_cjk_font_issue():
  243. import matplotlib
  244. ttf_name = os.path.basename(ALIB_FONT_FILE)
  245. local_ttf = os.path.join(os.path.abspath(os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)), 'fonts',
  246. 'ttf', ttf_name)
  247. if not os.path.exists(local_ttf):
  248. try:
  249. shutil.copy(ALIB_FONT_FILE, local_ttf)
  250. font_list_cache = os.path.join(matplotlib.get_cachedir(), 'fontlist-*.json')
  251. for cache_file in glob.glob(font_list_cache):
  252. with open(cache_file) as fin:
  253. cache_content = fin.read()
  254. if ttf_name not in cache_content:
  255. os.remove(cache_file)
  256. except Exception:
  257. print_traceback()
  258. def _escape_ansi(line: str) -> str:
  259. ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
  260. return ansi_escape.sub('', line)
  261. #
  262. # The _BasePolicy and AnyThreadEventLoopPolicy below are borrowed from Tornado.
  263. # Ref: https://www.tornadoweb.org/en/stable/_modules/tornado/platform/asyncio.html#AnyThreadEventLoopPolicy
  264. #
  265. if sys.platform == 'win32' and hasattr(asyncio, 'WindowsSelectorEventLoopPolicy'):
  266. _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
  267. else:
  268. _BasePolicy = asyncio.DefaultEventLoopPolicy
  269. class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
  270. """Event loop policy that allows loop creation on any thread.
  271. The default `asyncio` event loop policy only automatically creates
  272. event loops in the main threads. Other threads must create event
  273. loops explicitly or `asyncio.get_event_loop` (and therefore
  274. `.IOLoop.current`) will fail. Installing this policy allows event
  275. loops to be created automatically on any thread.
  276. Usage::
  277. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  278. """
  279. def get_event_loop(self) -> asyncio.AbstractEventLoop:
  280. try:
  281. return super().get_event_loop()
  282. except RuntimeError:
  283. # "There is no current event loop in thread %r"
  284. loop = self.new_event_loop()
  285. self.set_event_loop(loop)
  286. return loop