123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507 |
- import base64
- import copy
- import hashlib
- import json
- import os
- import re
- import shutil
- import signal
- import socket
- import sys
- import time
- import traceback
- import urllib.parse
- from io import BytesIO
- from typing import Any, List, Literal, Optional, Tuple, Union
- import json5
- import requests
- from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, FUNCTION, SYSTEM, USER, ContentItem, Message
- from qwen_agent.log import logger
def append_signal_handler(sig, handler):
    """
    Installs a new signal handler while preserving any existing handler.
    If an existing handler is present, it will be called _after_ the new handler.

    Args:
        sig: The signal number to hook (e.g. ``signal.SIGINT``).
        handler: Callable invoked first; it receives the arguments that Python
            passes to signal handlers.

    When the previous disposition is not a Python callable, a stand-in is
    synthesized for SIGINT (raises ``KeyboardInterrupt``) and for SIGTERM
    (raises ``SystemExit``). A signal that was explicitly ignored
    (``signal.SIG_IGN``) stays ignored: only the new handler runs.
    """
    previous = signal.getsignal(sig)
    if callable(previous):
        old_handler = previous
    elif previous == signal.SIG_IGN:
        # The signal was deliberately ignored; chaining a raising stand-in
        # would turn "ignore" into process termination.
        old_handler = None
    elif sig == signal.SIGINT:

        def old_handler(*args, **kwargs):
            raise KeyboardInterrupt
    elif sig == signal.SIGTERM:

        def old_handler(*args, **kwargs):
            raise SystemExit
    else:
        old_handler = None

    def new_handler(*args, **kwargs):
        handler(*args, **kwargs)
        if old_handler is not None:
            old_handler(*args, **kwargs)

    signal.signal(sig, new_handler)
def get_local_ip() -> str:
    """Return the local IPv4 address the OS picks for outbound traffic.

    A UDP socket is connected to a private address and the socket's own
    address is read back. If anything in that probe fails, '127.0.0.1' is
    returned instead. The socket is always closed.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    with probe:
        try:
            # doesn't even have to be reachable
            probe.connect(('10.255.255.255', 1))
            return probe.getsockname()[0]
        except Exception:
            return '127.0.0.1'
def hash_sha256(text: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``text`` (encoded via ``str.encode()``)."""
    return hashlib.sha256(text.encode()).hexdigest()
def print_traceback(is_error: bool = True):
    """Log the exception currently being handled, limited to 3 stack entries.

    Args:
        is_error: When True the traceback is logged at error level,
            otherwise at warning level.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    formatted = ''.join(traceback.format_exception(exc_type, exc_value, exc_tb, limit=3))
    emit = logger.error if is_error else logger.warning
    emit(formatted)
# Matches a single character in the range U+4E00..U+9FFF.
CHINESE_CHAR_RE = re.compile(r'[\u4e00-\u9fff]')


def has_chinese_chars(data: Any) -> bool:
    """Return True if the string form of ``data`` contains a character in U+4E00..U+9FFF."""
    return CHINESE_CHAR_RE.search(format(data)) is not None
def has_chinese_messages(messages: List[Union[Message, dict]], check_roles: Tuple[str] = (SYSTEM, USER)) -> bool:
    """Return True if any message whose role is in ``check_roles`` has Chinese characters in its content.

    Messages are accessed by key (``m['role']``, ``m['content']``); messages
    with other roles are skipped without inspecting their content.
    """
    return any(has_chinese_chars(m['content']) for m in messages if m['role'] in check_roles)
def get_basename_from_url(path_or_url: str) -> str:
    """Return the last path component of a local path or URL.

    The component is percent-decoded and stripped of surrounding whitespace.
    When the path part has no basename (e.g. "https://github.com/"), the last
    non-empty "/"-separated segment of the whole input is used instead.
    An input without any non-empty segment (e.g. "" or "/") yields "".

    Examples:
        "/mnt/a/b/c"                  -> "c"
        "https://github.com/here?k=v" -> "here"
        "https://github.com/"         -> "github.com"
    """
    if re.match(r'^[A-Za-z]:\\', path_or_url):
        # "C:\\a\\b\\c" -> "C:/a/b/c"
        path_or_url = path_or_url.replace('\\', '/')

    # "/mnt/a/b/c" -> "c"
    # "https://github.com/here?k=v" -> "here"
    # "https://github.com/" -> ""
    basename = urllib.parse.urlparse(path_or_url).path
    basename = os.path.basename(basename)
    basename = urllib.parse.unquote(basename)
    basename = basename.strip()

    # "https://github.com/" -> "" -> "github.com"
    if not basename:
        segments = [x.strip() for x in path_or_url.split('/') if x.strip()]
        # Inputs such as "" or "/" have no usable segment at all.
        basename = segments[-1] if segments else ''

    return basename
def is_http_url(path_or_url: str) -> bool:
    """Return True if the string starts with 'https://' or 'http://' (case-sensitive)."""
    return path_or_url.startswith(('https://', 'http://'))
def is_image(path_or_url: str) -> bool:
    """Return True if the basename of ``path_or_url`` has an image file extension.

    Recognized extensions (case-insensitive): .jpg, .jpeg, .png, .webp.
    The dot is part of the match, so a name that merely ends with the letters
    "png" (e.g. "mypng") is not treated as an image.
    """
    filename = get_basename_from_url(path_or_url).lower()
    return filename.endswith(('.jpg', '.jpeg', '.png', '.webp'))
def sanitize_chrome_file_path(file_path: str) -> str:
    """Resolve a path that may be a "file:///..." URL to an existing local path.

    Tried in order: the input as-is; the percent-decoded URL path component
    passed through ``sanitize_windows_file_path``; and finally
    ``sanitize_windows_file_path`` applied to the raw input, whose result is
    returned whether or not it exists.
    """
    if not os.path.exists(file_path):
        # Dealing with "file:///...":
        url_path = urllib.parse.urlparse(file_path).path
        decoded = sanitize_windows_file_path(urllib.parse.unquote(url_path))
        if os.path.exists(decoded):
            return decoded
        return sanitize_windows_file_path(file_path)
    return file_path
def sanitize_windows_file_path(file_path: str) -> str:
    """Return the first existing variant of ``file_path``, or ``file_path`` unchanged.

    Variants are probed in this order:
      1. the input as-is (Linux and macOS);
      2. the input without a leading '/' (native Windows, '/C:/...' -> 'C:/...');
      3. '/mnt/<drive>/...' when variant 2 looks like '<letter>:/...' (Windows + WSL);
      4. variant 2 with '/' replaced by a backslash (native Windows).
    """
    if os.path.exists(file_path):
        return file_path

    stripped = file_path[1:] if file_path.startswith('/') else file_path
    candidates = [stripped]
    if re.match(r'^[A-Za-z]:/', stripped):
        candidates.append(f'/mnt/{stripped[0].lower()}/{stripped[3:]}')
    candidates.append(stripped.replace('/', '\\'))

    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return file_path
def save_url_to_local_work_dir(url: str, save_dir: str, save_filename: str = '') -> str:
    """Copy or download ``url`` into ``save_dir`` and return the destination path.

    Args:
        url: An http(s) URL, or a local path / "file:///" style path.
        save_dir: Directory that receives the file.
        save_filename: Target file name; defaults to the basename derived from ``url``.

    Returns:
        The path of the saved file.

    Raises:
        ValueError: If the HTTP response status is not 200.

    An existing file at the destination is removed before the transfer starts.
    """
    target_name = save_filename or get_basename_from_url(url)
    new_path = os.path.join(save_dir, target_name)
    if os.path.exists(new_path):
        os.remove(new_path)

    logger.info(f'Downloading {url} to {new_path}...')
    started = time.time()
    if is_http_url(url):
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
        }
        response = requests.get(url, headers=headers)
        if response.status_code != 200:
            raise ValueError('Can not download this file. Please check your network or the file link.')
        with open(new_path, 'wb') as file:
            file.write(response.content)
    else:
        # Local source: normalize the path, then copy.
        url = sanitize_chrome_file_path(url)
        shutil.copy(url, new_path)
    finished = time.time()
    logger.info(f'Finished downloading {url} to {new_path}. Time spent: {finished - started} seconds.')
    return new_path
def save_text_to_file(path: str, text: str) -> None:
    """Write ``text`` to ``path`` as UTF-8, replacing any existing content."""
    with open(path, mode='w', encoding='utf-8') as out_file:
        out_file.write(text)
def read_text_from_file(path: str) -> str:
    """Read a text file, preferring UTF-8 and falling back to charset detection.

    If UTF-8 decoding fails, the traceback is logged as a warning and the
    file is decoded with the best match found by ``charset_normalizer``.
    """
    try:
        with open(path, 'r', encoding='utf-8') as src:
            return src.read()
    except UnicodeDecodeError:
        print_traceback(is_error=False)
        # Imported lazily: only needed when the file is not valid UTF-8.
        from charset_normalizer import from_path
        detected = from_path(path)
        return str(detected.best())
def contains_html_tags(text: str) -> bool:
    """Return True if ``text`` contains '<' directly followed by p, span, div, li, html or script.

    The pattern does not require a closing '>', so any tag name that starts
    with one of those prefixes also matches (e.g. '<pre').
    """
    tag_opening = r'<(p|span|div|li|html|script)[^>]*?'
    return re.search(tag_opening, text) is not None
def get_content_type_by_head_request(path: str) -> str:
    """Return the Content-Type header from a HEAD request to ``path``.

    The request uses a 5 second timeout. Returns '' when the header is
    absent and 'unk' when the request raises ``requests.RequestException``.
    """
    try:
        head_response = requests.head(path, timeout=5)
    except requests.RequestException:
        return 'unk'
    return head_response.headers.get('Content-Type', '')
def get_file_type(path: str) -> Literal['pdf', 'docx', 'pptx', 'txt', 'html', 'csv', 'tsv', 'xlsx', 'xls', 'unk']:
    """Classify a local path or URL into one of the supported file types.

    Resolution order:
      1. The extension of the basename, when it is one of the specially
         supported types (pdf, docx, pptx, csv, tsv, xlsx, xls).
      2. For http(s) URLs: the Content-Type reported by a HEAD request.
         PDF, legacy Word and the Office Open XML media types are recognized;
         everything else is assumed to be 'html'.
      3. For local files: 'html' if the text contains HTML tags, else 'txt'.
         'unk' is returned when the file cannot be read.
    """
    f_type = get_basename_from_url(path).split('.')[-1].lower()
    if f_type in ['pdf', 'docx', 'pptx', 'csv', 'tsv', 'xlsx', 'xls']:
        # Specially supported file types
        return f_type

    if is_http_url(path):
        # The HTTP header information for the response is obtained by making a HEAD request to the target URL,
        # where the Content-type field usually indicates the Type of Content to be returned
        content_type = get_content_type_by_head_request(path).lower()
        # (substring looked for in the Content-Type header, resulting file type)
        known_content_types = (
            ('application/pdf', 'pdf'),
            ('application/msword', 'docx'),
            ('application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'docx'),
            ('application/vnd.openxmlformats-officedocument.presentationml.presentation', 'pptx'),
            ('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'xlsx'),
        )
        for marker, detected in known_content_types:
            if marker in content_type:
                return detected
        # Assuming that the URL is HTML by default,
        # because the file downloaded by the request may contain html tags
        return 'html'
    else:
        # Determine by reading local HTML file
        try:
            content = read_text_from_file(path)
        except Exception:
            print_traceback()
            return 'unk'

        if contains_html_tags(content):
            return 'html'
        else:
            return 'txt'
def extract_urls(text: str) -> List[str]:
    """Return every 'http://' or 'https://' URL in ``text``, each extending to the next whitespace."""
    return re.findall(r'https?://\S+', text)
def extract_markdown_urls(md_text: str) -> List[str]:
    """Return the targets of all markdown links and images (``[text](target)`` / ``![alt](target)``)."""
    link_pattern = r'!?\[[^\]]*\]\(([^\)]+)\)'
    return [found.group(1) for found in re.finditer(link_pattern, md_text)]
def extract_code(text: str) -> str:
    """Extract code from a model response.

    Returns the body of the first triple-backtick block if there is one.
    Otherwise ``text`` is parsed with json5 and its 'code' entry is returned;
    if that fails, the traceback is logged as a warning and ``text`` is
    returned unchanged.
    """
    # Match triple backtick blocks first
    fenced = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
    if fenced:
        return fenced.group(1)

    try:
        return json5.loads(text)['code']
    except Exception:
        print_traceback(is_error=False)
    # If no code blocks found, return original text
    return text
def json_loads(text: str) -> dict:
    """Parse JSON text, tolerating a surrounding triple-backtick fence.

    Strict ``json`` is tried first, then ``json5``. If json5 also fails with
    ``ValueError``, the original ``json.decoder.JSONDecodeError`` is raised.
    """
    payload = text.strip('\n')
    is_fenced = payload.startswith('```') and payload.endswith('\n```')
    if is_fenced:
        # Drop the opening fence line and the closing fence line.
        payload = '\n'.join(payload.split('\n')[1:-1])
    try:
        return json.loads(payload)
    except json.decoder.JSONDecodeError as json_err:
        try:
            return json5.loads(payload)
        except ValueError:
            raise json_err
def json_dumps(obj: dict) -> str:
    """Serialize ``obj`` to JSON with 2-space indentation, keeping non-ASCII characters unescaped."""
    serialized = json.dumps(obj, indent=2, ensure_ascii=False)
    return serialized
def format_as_multimodal_message(
    msg: Message,
    add_upload_info: bool,
    lang: Literal['auto', 'en', 'zh'] = 'auto',
) -> Message:
    """Return a copy of ``msg`` whose content is a list of text/image items.

    String content becomes a single text item (or an empty list when the
    string is empty). For list content, text and image items are kept, while
    the values of file and image items are collected as "uploaded files".

    Args:
        msg: The message to convert; its role must be USER, ASSISTANT, SYSTEM or FUNCTION.
        add_upload_info: When True, and the role is SYSTEM or USER, and files
            were found, a text item listing the uploaded files as markdown
            links is prepended, unless an existing text item already contains it.
        lang: Language of the upload note. 'auto' picks Chinese when the
            message contains Chinese characters.

    Returns:
        A new Message; ``name`` is carried over only for FUNCTION messages.

    Raises:
        TypeError: If ``msg.content`` is neither a str nor a list.
    """
    assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
    content: List[ContentItem] = []
    if isinstance(msg.content, str):  # if text content
        if msg.content:
            content = [ContentItem(text=msg.content)]
    elif isinstance(msg.content, list):  # if multimodal content
        files = []
        for item in msg.content:
            k, v = item.get_type_and_value()
            if k == 'text':
                content.append(ContentItem(text=v))
            if k == 'image':
                content.append(item)
            if k in ('file', 'image'):
                # Move 'file' out of 'content' since it's not natively supported by models
                files.append(v)
        if add_upload_info and files and (msg.role in (SYSTEM, USER)):
            if lang == 'auto':
                has_zh = has_chinese_chars(msg)
            else:
                has_zh = (lang == 'zh')
            image_label, file_label = ('图片', '文件') if has_zh else ('image', 'file')
            upload = []
            for name in [get_basename_from_url(f) for f in files]:
                if is_image(name):
                    # Images are listed as markdown image links, parallel to the file links below.
                    upload.append(f'![{image_label}]({name})')
                else:
                    upload.append(f'[{file_label}]({name})')
            upload = ' '.join(upload)
            if has_zh:
                upload = f'(上传了 {upload})\n\n'
            else:
                upload = f'(Uploaded {upload})\n\n'

            # Check and avoid adding duplicate upload info
            upload_info_already_added = any(item.text and (upload in item.text) for item in content)
            if not upload_info_already_added:
                content = [ContentItem(text=upload)] + content
    else:
        raise TypeError
    msg = Message(
        role=msg.role,
        content=content,
        name=msg.name if msg.role == FUNCTION else None,
        function_call=msg.function_call,
    )
    return msg
def format_as_text_message(
    msg: Message,
    add_upload_info: bool,
    lang: Literal['auto', 'en', 'zh'] = 'auto',
) -> Message:
    """Return a copy of ``msg`` whose content is one plain string.

    The message is first normalized with ``format_as_multimodal_message``;
    the values of its text items are then concatenated in order, and all
    non-text items are dropped.
    """
    formatted = format_as_multimodal_message(msg, add_upload_info=add_upload_info, lang=lang)
    formatted.content = ''.join(item.value for item in formatted.content if item.type == 'text')
    return formatted
def extract_text_from_message(
    msg: Message,
    add_upload_info: bool,
    lang: Literal['auto', 'en', 'zh'] = 'auto',
) -> str:
    """Return the whitespace-stripped text of ``msg``.

    String content is used directly; list content is flattened with
    ``format_as_text_message``.

    Raises:
        TypeError: If ``msg.content`` is neither a list nor a str.
    """
    if isinstance(msg.content, str):
        return msg.content.strip()
    if isinstance(msg.content, list):
        flattened = format_as_text_message(msg, add_upload_info=add_upload_info, lang=lang)
        return flattened.content.strip()
    raise TypeError(f'List of str or str expected, but received {type(msg.content).__name__}.')
def extract_files_from_messages(messages: List[Message], include_images: bool) -> List[str]:
    """Collect the distinct file references found in list-typed message contents.

    Args:
        messages: Messages to scan; messages whose content is not a list are skipped.
        include_images: When True, image references are collected as well.

    Returns:
        References in first-seen order, without duplicates.
    """
    collected = []
    for msg in messages:
        if not isinstance(msg.content, list):
            continue
        for item in msg.content:
            candidates = [item.file]
            if include_images:
                candidates.append(item.image)
            for ref in candidates:
                if ref and ref not in collected:
                    collected.append(ref)
    return collected
def merge_generate_cfgs(base_generate_cfg: Optional[dict], new_generate_cfg: Optional[dict]) -> dict:
    """Merge two generation configs into a new dict.

    The base config is deep-copied, so neither input is modified. Keys from
    ``new_generate_cfg`` override the base, except 'stop': new stop entries
    that are not already in the base list are appended to it.
    """
    merged: dict = copy.deepcopy(base_generate_cfg or {})
    for key, value in (new_generate_cfg or {}).items():
        if key != 'stop':
            merged[key] = value
            continue
        existing_stops = merged.get('stop', [])
        merged['stop'] = existing_stops + [s for s in value if s not in existing_stops]
    return merged
def build_text_completion_prompt(
    messages: List[Message],
    allow_special: bool = False,
    default_system: str = DEFAULT_SYSTEM_MESSAGE,
) -> str:
    """Render chat messages as one text-completion prompt.

    Every turn is wrapped as ``<|im_start|>{role}\\n{content}<|im_end|>``.
    The prompt always begins with a system turn (``default_system`` when the
    first message is not a system message) and always ends with an open
    assistant turn, i.e. the final ``<|im_end|>`` is removed so that the
    model continues the assistant's text.

    Args:
        messages: Conversation so far; every content must be a str. An empty
            list, or a list holding only a system message, produces a prompt
            with just the system turn and an empty assistant turn.
        allow_special: When True, SYSTEM/FUNCTION roles are allowed and an
            assistant ``function_call`` is rendered as a ``<tool_call>`` block.
            When False, only USER/ASSISTANT messages without function calls
            are accepted.
        default_system: System text used when no system message is supplied.

    Returns:
        The prompt string.
    """
    im_start = '<|im_start|>'
    im_end = '<|im_end|>'

    if messages and messages[0].role == SYSTEM:
        system_text = messages[0].content
        assert isinstance(system_text, str)
        prompt = f'{im_start}{SYSTEM}\n{system_text}{im_end}'
        messages = messages[1:]
    else:
        prompt = f'{im_start}{SYSTEM}\n{default_system}{im_end}'

    # Make sure we are completing the chat in the tone of the assistant
    # (also covers the case where no non-system message is left).
    if (not messages) or messages[-1].role != ASSISTANT:
        messages = messages + [Message(ASSISTANT, '')]

    for msg in messages:
        assert isinstance(msg.content, str)
        content = msg.content.lstrip('\n').rstrip()
        if allow_special:
            assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
            if msg.function_call:
                assert msg.role == ASSISTANT
                tool_call = msg.function_call.arguments
                try:
                    tool_call = {'name': msg.function_call.name, 'arguments': json.loads(tool_call)}
                    tool_call = json.dumps(tool_call, ensure_ascii=False, indent=2)
                except json.decoder.JSONDecodeError:
                    # Arguments are not valid JSON: embed them verbatim.
                    tool_call = '{"name": "' + msg.function_call.name + '", "arguments": ' + tool_call + '}'
                if content:
                    content += '\n'
                content += f'<tool_call>\n{tool_call}\n</tool_call>'
        else:
            assert msg.role in (USER, ASSISTANT)
            assert msg.function_call is None
        prompt += f'\n{im_start}{msg.role}\n{content}{im_end}'

    # Leave the last (assistant) turn open for completion.
    assert prompt.endswith(im_end)
    prompt = prompt[:-len(im_end)]
    return prompt
def encode_image_as_base64(path: str, max_short_side_length: int = -1) -> str:
    """Load an image file and return it as a JPEG ``data:`` URI.

    Args:
        path: Path of the image file.
        max_short_side_length: When positive and the image's shorter side
            exceeds it, the image is resized with ``resize_image`` so that
            its shorter side equals this value. Non-positive disables resizing.

    Returns:
        ``'data:image/jpeg;base64,'`` followed by the base64-encoded JPEG bytes
        of the image converted to RGB.
    """
    from PIL import Image

    # The context manager closes the underlying file once encoding is done.
    with Image.open(path) as opened:
        image = opened
        if (max_short_side_length > 0) and (min(image.size) > max_short_side_length):
            ori_size = image.size
            image = resize_image(image, short_side_length=max_short_side_length)
            logger.debug(f'Image "{path}" resized from {ori_size} to {image.size}.')
        image = image.convert(mode='RGB')
        buffered = BytesIO()
        image.save(buffered, format='JPEG')
    return 'data:image/jpeg;base64,' + base64.b64encode(buffered.getvalue()).decode('utf-8')
def load_image_from_base64(image_base64: Union[bytes, str]):
    """Decode base64 image data and return it as a fully loaded PIL image."""
    from PIL import Image

    raw_bytes = base64.b64decode(image_base64)
    decoded_image = Image.open(BytesIO(raw_bytes))
    # Force the pixel data to be read now rather than lazily.
    decoded_image.load()
    return decoded_image
def resize_image(img, short_side_length: int = 1080):
    """Return a copy of ``img`` resized so its shorter side equals ``short_side_length``.

    The longer side is scaled by the same factor and truncated to an int.
    Bilinear resampling is used. ``img`` must be a ``PIL.Image.Image``.
    """
    from PIL import Image
    assert isinstance(img, Image.Image)

    width, height = img.size
    if width > height:
        target_size = (int((short_side_length / height) * width), short_side_length)
    else:
        target_size = (short_side_length, int((short_side_length / width) * height))
    return img.resize(target_size, resample=Image.Resampling.BILINEAR)
def get_last_usr_msg_idx(messages: List[Union[dict, Message]]) -> int:
    """Return the index of the last message whose role is 'user'.

    An assertion fails (carrying ``messages``) when there is no such message.
    """
    backwards = range(len(messages) - 1, -1, -1)
    last_user = next((pos for pos in backwards if messages[pos]['role'] == 'user'), -1)
    assert last_user >= 0, messages
    assert messages[last_user]['role'] == 'user'
    return last_user
|