123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- import base64
- import io
- import json
- import logging
- import os
- import queue
- import re
- import subprocess
- import sys
- import time
- import traceback
- import uuid
- import matplotlib
- import PIL.Image
- from jupyter_client import BlockingKernelClient
- from utils.code_utils import extract_code
# Directory where kernel launch scripts, connection files and rendered images
# are written; overridable via the CODE_INTERPRETER_WORK_DIR environment variable.
WORK_DIR = os.environ.get('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')

# Source of the tiny script that each kernel subprocess executes.
LAUNCH_KERNEL_PY = """
from ipykernel import kernelapp as app
app.launch_new_instance()
"""

# Kernel clients created so far, keyed by the PID of the process that started them.
_KERNEL_CLIENTS = dict()
# Run this fix before Jupyter starts if matplotlib cannot render CJK fonts.
# In addition, the following lines must be run inside the Jupyter notebook:
# ```python
# import matplotlib.pyplot as plt
# plt.rcParams['font.sans-serif'] = ['SimHei']
# plt.rcParams['axes.unicode_minus'] = False
# ```
def fix_matplotlib_cjk_font_issue():
    """Warn if the SimHei font is missing next to matplotlib's config file.

    The font is looked up at ``fonts/ttf/simhei.ttf`` inside the directory
    that contains the file reported by ``matplotlib.matplotlib_fname()``.
    Despite its name, this function installs nothing: it only logs a
    warning when the font file does not exist.
    """
    mpl_config_dir = os.path.abspath(
        os.path.join(matplotlib.matplotlib_fname(), os.path.pardir))
    local_ttf = os.path.join(mpl_config_dir, 'fonts', 'ttf', 'simhei.ttf')
    if os.path.exists(local_ttf):
        return
    logging.warning(f'Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib.')
def start_kernel(pid):
    """Launch an IPython kernel subprocess and return a client connected to it.

    A launch script and a connection file, both named after ``pid``, are
    placed in ``WORK_DIR``; stale copies from a previous run are removed
    first. The function then waits until the kernel has written a complete
    (JSON-parseable) connection file before connecting.

    Args:
        pid: Identifier used to name the per-process launch script and
            connection file (callers in this file pass ``os.getpid()``).

    Returns:
        A ``BlockingKernelClient`` with its channels started and the kernel
        reported as ready.

    Raises:
        RuntimeError: If the kernel subprocess exits before a usable
            connection file appears.
    """
    fix_matplotlib_cjk_font_issue()

    connection_file = os.path.join(WORK_DIR,
                                   f'kernel_connection_file_{pid}.json')
    launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py')
    for f in [connection_file, launch_kernel_script]:
        if os.path.exists(f):
            logging.warning(f'{f} already exists')
            os.remove(f)

    os.makedirs(WORK_DIR, exist_ok=True)
    with open(launch_kernel_script, 'w') as fout:
        fout.write(LAUNCH_KERNEL_PY)

    kernel_process = subprocess.Popen([
        sys.executable,
        launch_kernel_script,
        '--IPKernelApp.connection_file',
        connection_file,
        '--matplotlib=inline',
        '--quiet',
    ],
                                      cwd=WORK_DIR)
    logging.info(f"INFO: kernel process's PID = {kernel_process.pid}")

    # Wait for the kernel connection file to be written.
    while True:
        if os.path.isfile(connection_file):
            # The file may be only partially written; keep waiting until it
            # parses as JSON.
            try:
                with open(connection_file, 'r') as fp:
                    json.load(fp)
                break
            except json.JSONDecodeError:
                pass
        # Without this check a kernel that failed to start would leave the
        # caller waiting forever for a file that will never appear.
        exit_code = kernel_process.poll()
        if exit_code is not None:
            raise RuntimeError(
                f'Kernel process exited with code {exit_code} before '
                f'writing a valid connection file: {connection_file}')
        # Sleep on every retry, including after a parse failure, so the loop
        # never busy-spins on a half-written file.
        time.sleep(0.1)

    # Client
    kc = BlockingKernelClient(connection_file=connection_file)
    kc.load_connection_file()
    kc.start_channels()
    kc.wait_for_ready()
    return kc
def escape_ansi(line):
    """Return ``line`` with ANSI escape sequences removed.

    A sequence is recognised as an introducer (ESC followed by a character
    in ``@``-``_``, or a single character in the ``\\x80``-``\\x9F`` range),
    optional parameter and intermediate characters, and one final character
    in ``@``-``~``.
    """
    return re.sub(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]', '', line)
def publish_image_to_local(image_base64: str):
    """Decode a base64-encoded image and save it as a PNG inside ``WORK_DIR``.

    The file gets a freshly generated UUID4 name with a ``.png`` suffix.

    Args:
        image_base64: Base64 text of the image data.

    Returns:
        The path of the PNG file that was written.
    """
    local_image_file = os.path.join(WORK_DIR, f'{uuid.uuid4()}.png')
    png_bytes = base64.b64decode(image_base64)
    assert isinstance(png_bytes, bytes)
    # Round-trip through PIL so the file on disk is re-encoded as PNG.
    decoded_image = PIL.Image.open(io.BytesIO(png_bytes))
    decoded_image.save(local_image_file, 'png')
    return local_image_file
# Bootstrap script executed inside every freshly started (or reset) kernel.
# It installs a SIGALRM handler that raises a TimeoutError carrying the
# M6_CODE_INTERPRETER_TIMEOUT marker, disables input(), changes into
# ./upload_file/ when the working directory does not already contain that
# name, and pre-imports the common scientific stack.
# The bodies of the embedded `def` and `if` statements must stay indented:
# this text is sent to the kernel verbatim and has to be valid Python.
START_CODE = """
import signal
def _m6_code_interpreter_timeout_handler(signum, frame):
    raise TimeoutError("M6_CODE_INTERPRETER_TIMEOUT")
signal.signal(signal.SIGALRM, _m6_code_interpreter_timeout_handler)
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')
import os
if 'upload_file' not in os.getcwd():
    os.chdir("./upload_file/")
import math
import re
import json
import seaborn as sns
sns.set_theme()
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
import numpy as np
import pandas as pd
from sympy import Eq, symbols, solve
"""
def code_interpreter(action_input_list: list, timeout=30, clear=False):
    """Extract code from each action input, patch it, and run it in the kernel.

    The code pulled out of every entry by ``extract_code`` is concatenated,
    one entry per line group. After any line that starts with
    ``sns.set_theme(`` the SimHei font settings are re-applied, and if the
    code defines ``solution()`` a call to it is appended.

    Args:
        action_input_list: Raw action inputs to extract code from.
        timeout: Passed through to ``_code_interpreter``.
        clear: Passed through to ``_code_interpreter``.

    Returns:
        Whatever ``_code_interpreter`` returns for the patched code.
    """
    code = ''.join(
        extract_code(action_input) + '\n'
        for action_input in action_input_list)

    patched_lines = []
    for source_line in code.split('\n'):
        patched_lines.append(source_line)
        if source_line.startswith('sns.set_theme('):
            # Re-apply the font configuration right after the theme call.
            patched_lines.extend([
                'plt.rcParams["font.sans-serif"] = ["SimHei"]',
                'plt.rcParams["axes.unicode_minus"] = False',
            ])

    script = '\n'.join(patched_lines)
    if 'def solution()' in script:
        script = script + '\nsolution()'
    return _code_interpreter(script, timeout, clear)
def _format_image_markdown(image_idx, image_path):
    """Return a Markdown image link for the ``image_idx``-th figure of a run."""
    return '![fig-%03d](%s)' % (image_idx, image_path)


def _code_interpreter(code: str, timeout, clear=False):
    """Execute ``code`` in this process's kernel and collect its output.

    A kernel is started on first use (keyed by ``os.getpid()``) and primed
    with ``START_CODE``. IOPub messages are then read until the kernel
    reports the ``idle`` execution state, and text and image outputs are
    accumulated into one string.

    Args:
        code: Python source to run. Blank code returns ``''`` immediately.
        timeout: Seconds before the kernel-side alarm fires; falsy disables
            the alarm. The alarm handler is the one installed by
            ``START_CODE``.
        clear: When true, reset the kernel namespace and re-run
            ``START_CODE`` before ``code``.

    Returns:
        The collected output. Each text chunk is rendered as
        ``<kind>:`` followed by a fenced block, where ``<kind>`` is the
        message type or, for stream output, the stream name. Each PNG is
        saved locally and rendered as a Markdown image link.
    """
    if not code.strip():
        return ''
    if timeout:
        code = f'signal.alarm({timeout})\n{code}'
    if clear:
        code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE + code

    pid = os.getpid()
    if pid not in _KERNEL_CLIENTS:
        _KERNEL_CLIENTS[pid] = start_kernel(pid)
        # Prime the new kernel with the alarm handler and default imports.
        _code_interpreter(START_CODE, timeout=None)
    kc = _KERNEL_CLIENTS[pid]
    kc.wait_for_ready()
    kc.execute(code)

    result = ''
    image_idx = 0
    while True:
        text = ''
        image = ''
        finished = False
        # Label used if an exception occurs before a message type is known.
        msg_type = 'error'
        try:
            msg = kc.get_iopub_msg()
            msg_type = msg['msg_type']
            if msg_type == 'status':
                if msg['content'].get('execution_state') == 'idle':
                    finished = True
            elif msg_type in ('execute_result', 'display_data'):
                data = msg['content']['data']
                has_image = 'image/png' in data
                if has_image:
                    image_url = publish_image_to_local(data['image/png'])
                    image_idx += 1
                    image = _format_image_markdown(image_idx, image_url)
                # execute_result always reports its text representation;
                # display_data falls back to text only when there is no PNG.
                if msg_type == 'execute_result' or not has_image:
                    text = data.get('text/plain', '')
            elif msg_type == 'stream':
                msg_type = msg['content']['name']  # stdout, stderr
                text = msg['content']['text']
            elif msg_type == 'error':
                text = escape_ansi('\n'.join(msg['content']['traceback']))
                # Marker raised by the SIGALRM handler from START_CODE.
                if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                    text = f'Timeout. No response after {timeout} seconds.'
        except queue.Empty:
            # NOTE(review): get_iopub_msg() is called without a timeout here,
            # so it is unclear whether this branch can be reached - confirm.
            text = f'Timeout. No response after {timeout} seconds.'
            finished = True
        except Exception:
            text = 'The code interpreter encountered an unexpected error.'
            logging.warning(''.join(traceback.format_exception(*sys.exc_info())))
            finished = True
        if text:
            result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
        if image:
            result += f'\n\n{image}'
        if finished:
            break
    result = result.lstrip('\n')

    if timeout:
        # Cancel any alarm still pending in the kernel.
        _code_interpreter('signal.alarm(0)', timeout=None)
    return result
def get_multiline_input(hint):
    """Prompt with ``hint`` and read lines from stdin until end-of-file.

    Args:
        hint: Text printed before the usage instructions.

    Returns:
        The lines entered, joined with newlines; ``''`` if none were entered.
    """
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    collected = []
    reading = True
    while reading:
        try:
            collected.append(input())
        except EOFError:  # CTRL-D
            reading = False
    print('// Input received.')
    # Joining an empty list already yields the empty string.
    return '\n'.join(collected)
if __name__ == '__main__':
    # Interactive loop: read a multi-line snippet, run it, print the output.
    while True:
        user_code = get_multiline_input('Enter python code:')
        output = code_interpreter([user_code])
        print(output)
|