utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. import base64
  2. import copy
  3. import hashlib
  4. import json
  5. import os
  6. import re
  7. import shutil
  8. import signal
  9. import socket
  10. import sys
  11. import time
  12. import traceback
  13. import urllib.parse
  14. from io import BytesIO
  15. from typing import Any, List, Literal, Optional, Tuple, Union
  16. import json5
  17. import requests
  18. from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, FUNCTION, SYSTEM, USER, ContentItem, Message
  19. from qwen_agent.log import logger
  20. def append_signal_handler(sig, handler):
  21. """
  22. Installs a new signal handler while preserving any existing handler.
  23. If an existing handler is present, it will be called _after_ the new handler.
  24. """
  25. old_handler = signal.getsignal(sig)
  26. if not callable(old_handler):
  27. old_handler = None
  28. if sig == signal.SIGINT:
  29. def old_handler(*args, **kwargs):
  30. raise KeyboardInterrupt
  31. elif sig == signal.SIGTERM:
  32. def old_handler(*args, **kwargs):
  33. raise SystemExit
  34. def new_handler(*args, **kwargs):
  35. handler(*args, **kwargs)
  36. if old_handler is not None:
  37. old_handler(*args, **kwargs)
  38. signal.signal(sig, new_handler)
  39. def get_local_ip() -> str:
  40. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  41. try:
  42. # doesn't even have to be reachable
  43. s.connect(('10.255.255.255', 1))
  44. ip = s.getsockname()[0]
  45. except Exception:
  46. ip = '127.0.0.1'
  47. finally:
  48. s.close()
  49. return ip
  50. def hash_sha256(text: str) -> str:
  51. hash_object = hashlib.sha256(text.encode())
  52. key = hash_object.hexdigest()
  53. return key
  54. def print_traceback(is_error: bool = True):
  55. tb = ''.join(traceback.format_exception(*sys.exc_info(), limit=3))
  56. if is_error:
  57. logger.error(tb)
  58. else:
  59. logger.warning(tb)
  60. CHINESE_CHAR_RE = re.compile(r'[\u4e00-\u9fff]')
  61. def has_chinese_chars(data: Any) -> bool:
  62. text = f'{data}'
  63. return bool(CHINESE_CHAR_RE.search(text))
  64. def has_chinese_messages(messages: List[Union[Message, dict]], check_roles: Tuple[str] = (SYSTEM, USER)) -> bool:
  65. for m in messages:
  66. if m['role'] in check_roles:
  67. if has_chinese_chars(m['content']):
  68. return True
  69. return False
  70. def get_basename_from_url(path_or_url: str) -> str:
  71. if re.match(r'^[A-Za-z]:\\', path_or_url):
  72. # "C:\\a\\b\\c" -> "C:/a/b/c"
  73. path_or_url = path_or_url.replace('\\', '/')
  74. # "/mnt/a/b/c" -> "c"
  75. # "https://github.com/here?k=v" -> "here"
  76. # "https://github.com/" -> ""
  77. basename = urllib.parse.urlparse(path_or_url).path
  78. basename = os.path.basename(basename)
  79. basename = urllib.parse.unquote(basename)
  80. basename = basename.strip()
  81. # "https://github.com/" -> "" -> "github.com"
  82. if not basename:
  83. basename = [x.strip() for x in path_or_url.split('/') if x.strip()][-1]
  84. return basename
  85. def is_http_url(path_or_url: str) -> bool:
  86. if path_or_url.startswith('https://') or path_or_url.startswith('http://'):
  87. return True
  88. return False
  89. def is_image(path_or_url: str) -> bool:
  90. filename = get_basename_from_url(path_or_url).lower()
  91. for ext in ['jpg', 'jpeg', 'png', 'webp']:
  92. if filename.endswith(ext):
  93. return True
  94. return False
  95. def sanitize_chrome_file_path(file_path: str) -> str:
  96. if os.path.exists(file_path):
  97. return file_path
  98. # Dealing with "file:///...":
  99. new_path = urllib.parse.urlparse(file_path)
  100. new_path = urllib.parse.unquote(new_path.path)
  101. new_path = sanitize_windows_file_path(new_path)
  102. if os.path.exists(new_path):
  103. return new_path
  104. return sanitize_windows_file_path(file_path)
  105. def sanitize_windows_file_path(file_path: str) -> str:
  106. # For Linux and macOS.
  107. if os.path.exists(file_path):
  108. return file_path
  109. # For native Windows, drop the leading '/' in '/C:/'
  110. win_path = file_path
  111. if win_path.startswith('/'):
  112. win_path = win_path[1:]
  113. if os.path.exists(win_path):
  114. return win_path
  115. # For Windows + WSL.
  116. if re.match(r'^[A-Za-z]:/', win_path):
  117. wsl_path = f'/mnt/{win_path[0].lower()}/{win_path[3:]}'
  118. if os.path.exists(wsl_path):
  119. return wsl_path
  120. # For native Windows, replace / with \.
  121. win_path = win_path.replace('/', '\\')
  122. if os.path.exists(win_path):
  123. return win_path
  124. return file_path
  125. def save_url_to_local_work_dir(url: str, save_dir: str, save_filename: str = '') -> str:
  126. if not save_filename:
  127. save_filename = get_basename_from_url(url)
  128. new_path = os.path.join(save_dir, save_filename)
  129. if os.path.exists(new_path):
  130. os.remove(new_path)
  131. logger.info(f'Downloading {url} to {new_path}...')
  132. start_time = time.time()
  133. if not is_http_url(url):
  134. url = sanitize_chrome_file_path(url)
  135. shutil.copy(url, new_path)
  136. else:
  137. headers = {
  138. 'User-Agent':
  139. 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
  140. }
  141. response = requests.get(url, headers=headers)
  142. if response.status_code == 200:
  143. with open(new_path, 'wb') as file:
  144. file.write(response.content)
  145. else:
  146. raise ValueError('Can not download this file. Please check your network or the file link.')
  147. end_time = time.time()
  148. logger.info(f'Finished downloading {url} to {new_path}. Time spent: {end_time - start_time} seconds.')
  149. return new_path
  150. def save_text_to_file(path: str, text: str) -> None:
  151. with open(path, 'w', encoding='utf-8') as fp:
  152. fp.write(text)
  153. def read_text_from_file(path: str) -> str:
  154. try:
  155. with open(path, 'r', encoding='utf-8') as file:
  156. file_content = file.read()
  157. except UnicodeDecodeError:
  158. print_traceback(is_error=False)
  159. from charset_normalizer import from_path
  160. results = from_path(path)
  161. file_content = str(results.best())
  162. return file_content
  163. def contains_html_tags(text: str) -> bool:
  164. pattern = r'<(p|span|div|li|html|script)[^>]*?'
  165. return bool(re.search(pattern, text))
  166. def get_content_type_by_head_request(path: str) -> str:
  167. try:
  168. response = requests.head(path, timeout=5)
  169. content_type = response.headers.get('Content-Type', '')
  170. return content_type
  171. except requests.RequestException:
  172. return 'unk'
  173. def get_file_type(path: str) -> Literal['pdf', 'docx', 'pptx', 'txt', 'html', 'csv', 'tsv', 'xlsx', 'xls', 'unk']:
  174. f_type = get_basename_from_url(path).split('.')[-1].lower()
  175. if f_type in ['pdf', 'docx', 'pptx', 'csv', 'tsv', 'xlsx', 'xls']:
  176. # Specially supported file types
  177. return f_type
  178. if is_http_url(path):
  179. # The HTTP header information for the response is obtained by making a HEAD request to the target URL,
  180. # where the Content-type field usually indicates the Type of Content to be returned
  181. content_type = get_content_type_by_head_request(path)
  182. if 'application/pdf' in content_type:
  183. return 'pdf'
  184. elif 'application/msword' in content_type:
  185. return 'docx'
  186. # Assuming that the URL is HTML by default,
  187. # because the file downloaded by the request may contain html tags
  188. return 'html'
  189. else:
  190. # Determine by reading local HTML file
  191. try:
  192. content = read_text_from_file(path)
  193. except Exception:
  194. print_traceback()
  195. return 'unk'
  196. if contains_html_tags(content):
  197. return 'html'
  198. else:
  199. return 'txt'
  200. def extract_urls(text: str) -> List[str]:
  201. pattern = re.compile(r'https?://\S+')
  202. urls = re.findall(pattern, text)
  203. return urls
  204. def extract_markdown_urls(md_text: str) -> List[str]:
  205. pattern = r'!?\[[^\]]*\]\(([^\)]+)\)'
  206. urls = re.findall(pattern, md_text)
  207. return urls
  208. def extract_code(text: str) -> str:
  209. # Match triple backtick blocks first
  210. triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
  211. if triple_match:
  212. text = triple_match.group(1)
  213. else:
  214. try:
  215. text = json5.loads(text)['code']
  216. except Exception:
  217. print_traceback(is_error=False)
  218. # If no code blocks found, return original text
  219. return text
  220. def json_loads(text: str) -> dict:
  221. text = text.strip('\n')
  222. if text.startswith('```') and text.endswith('\n```'):
  223. text = '\n'.join(text.split('\n')[1:-1])
  224. try:
  225. return json.loads(text)
  226. except json.decoder.JSONDecodeError as json_err:
  227. try:
  228. return json5.loads(text)
  229. except ValueError:
  230. raise json_err
  231. def json_dumps(obj: dict) -> str:
  232. return json.dumps(obj, ensure_ascii=False, indent=2)
  233. def format_as_multimodal_message(
  234. msg: Message,
  235. add_upload_info: bool,
  236. lang: Literal['auto', 'en', 'zh'] = 'auto',
  237. ) -> Message:
  238. assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
  239. content: List[ContentItem] = []
  240. if isinstance(msg.content, str): # if text content
  241. if msg.content:
  242. content = [ContentItem(text=msg.content)]
  243. elif isinstance(msg.content, list): # if multimodal content
  244. files = []
  245. for item in msg.content:
  246. k, v = item.get_type_and_value()
  247. if k == 'text':
  248. content.append(ContentItem(text=v))
  249. if k == 'image':
  250. content.append(item)
  251. if k in ('file', 'image'):
  252. # Move 'file' out of 'content' since it's not natively supported by models
  253. files.append(v)
  254. if add_upload_info and files and (msg.role in (SYSTEM, USER)):
  255. if lang == 'auto':
  256. has_zh = has_chinese_chars(msg)
  257. else:
  258. has_zh = (lang == 'zh')
  259. upload = []
  260. for f in [get_basename_from_url(f) for f in files]:
  261. if is_image(f):
  262. if has_zh:
  263. upload.append(f'![图片]({f})')
  264. else:
  265. upload.append(f'![image]({f})')
  266. else:
  267. if has_zh:
  268. upload.append(f'[文件]({f})')
  269. else:
  270. upload.append(f'[file]({f})')
  271. upload = ' '.join(upload)
  272. if has_zh:
  273. upload = f'(上传了 {upload})\n\n'
  274. else:
  275. upload = f'(Uploaded {upload})\n\n'
  276. # Check and avoid adding duplicate upload info
  277. upload_info_already_added = False
  278. for item in content:
  279. if item.text and (upload in item.text):
  280. upload_info_already_added = True
  281. if not upload_info_already_added:
  282. content = [ContentItem(text=upload)] + content
  283. else:
  284. raise TypeError
  285. msg = Message(
  286. role=msg.role,
  287. content=content,
  288. name=msg.name if msg.role == FUNCTION else None,
  289. function_call=msg.function_call,
  290. )
  291. return msg
  292. def format_as_text_message(
  293. msg: Message,
  294. add_upload_info: bool,
  295. lang: Literal['auto', 'en', 'zh'] = 'auto',
  296. ) -> Message:
  297. msg = format_as_multimodal_message(msg, add_upload_info=add_upload_info, lang=lang)
  298. text = ''
  299. for item in msg.content:
  300. if item.type == 'text':
  301. text += item.value
  302. msg.content = text
  303. return msg
  304. def extract_text_from_message(
  305. msg: Message,
  306. add_upload_info: bool,
  307. lang: Literal['auto', 'en', 'zh'] = 'auto',
  308. ) -> str:
  309. if isinstance(msg.content, list):
  310. text = format_as_text_message(msg, add_upload_info=add_upload_info, lang=lang).content
  311. elif isinstance(msg.content, str):
  312. text = msg.content
  313. else:
  314. raise TypeError(f'List of str or str expected, but received {type(msg.content).__name__}.')
  315. return text.strip()
  316. def extract_files_from_messages(messages: List[Message], include_images: bool) -> List[str]:
  317. files = []
  318. for msg in messages:
  319. if isinstance(msg.content, list):
  320. for item in msg.content:
  321. if item.file and item.file not in files:
  322. files.append(item.file)
  323. if include_images and item.image and item.image not in files:
  324. files.append(item.image)
  325. return files
  326. def merge_generate_cfgs(base_generate_cfg: Optional[dict], new_generate_cfg: Optional[dict]) -> dict:
  327. generate_cfg: dict = copy.deepcopy(base_generate_cfg or {})
  328. if new_generate_cfg:
  329. for k, v in new_generate_cfg.items():
  330. if k == 'stop':
  331. stop = generate_cfg.get('stop', [])
  332. stop = stop + [s for s in v if s not in stop]
  333. generate_cfg['stop'] = stop
  334. else:
  335. generate_cfg[k] = v
  336. return generate_cfg
  337. def build_text_completion_prompt(
  338. messages: List[Message],
  339. allow_special: bool = False,
  340. default_system: str = DEFAULT_SYSTEM_MESSAGE,
  341. ) -> str:
  342. im_start = '<|im_start|>'
  343. im_end = '<|im_end|>'
  344. if messages[0].role == SYSTEM:
  345. sys = messages[0].content
  346. assert isinstance(sys, str)
  347. prompt = f'{im_start}{SYSTEM}\n{sys}{im_end}'
  348. messages = messages[1:]
  349. else:
  350. prompt = f'{im_start}{SYSTEM}\n{default_system}{im_end}'
  351. # Make sure we are completing the chat in the tone of the assistant
  352. if messages[-1].role != ASSISTANT:
  353. messages = messages + [Message(ASSISTANT, '')]
  354. for msg in messages:
  355. assert isinstance(msg.content, str)
  356. content = msg.content.lstrip('\n').rstrip()
  357. if allow_special:
  358. assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
  359. if msg.function_call:
  360. assert msg.role == ASSISTANT
  361. tool_call = msg.function_call.arguments
  362. try:
  363. tool_call = {'name': msg.function_call.name, 'arguments': json.loads(tool_call)}
  364. tool_call = json.dumps(tool_call, ensure_ascii=False, indent=2)
  365. except json.decoder.JSONDecodeError:
  366. tool_call = '{"name": "' + msg.function_call.name + '", "arguments": ' + tool_call + '}'
  367. if content:
  368. content += '\n'
  369. content += f'<tool_call>\n{tool_call}\n</tool_call>'
  370. else:
  371. assert msg.role in (USER, ASSISTANT)
  372. assert msg.function_call is None
  373. prompt += f'\n{im_start}{msg.role}\n{content}{im_end}'
  374. assert prompt.endswith(im_end)
  375. prompt = prompt[:-len(im_end)]
  376. return prompt
  377. def encode_image_as_base64(path: str, max_short_side_length: int = -1) -> str:
  378. from PIL import Image
  379. image = Image.open(path)
  380. if (max_short_side_length > 0) and (min(image.size) > max_short_side_length):
  381. ori_size = image.size
  382. image = resize_image(image, short_side_length=max_short_side_length)
  383. logger.debug(f'Image "{path}" resized from {ori_size} to {image.size}.')
  384. image = image.convert(mode='RGB')
  385. buffered = BytesIO()
  386. image.save(buffered, format='JPEG')
  387. return 'data:image/jpeg;base64,' + base64.b64encode(buffered.getvalue()).decode('utf-8')
  388. def load_image_from_base64(image_base64: Union[bytes, str]):
  389. from PIL import Image
  390. image = Image.open(BytesIO(base64.b64decode(image_base64)))
  391. image.load()
  392. return image
  393. def resize_image(img, short_side_length: int = 1080):
  394. from PIL import Image
  395. assert isinstance(img, Image.Image)
  396. width, height = img.size
  397. if width <= height:
  398. new_width = short_side_length
  399. new_height = int((short_side_length / width) * height)
  400. else:
  401. new_height = short_side_length
  402. new_width = int((short_side_length / height) * width)
  403. resized_img = img.resize((new_width, new_height), resample=Image.Resampling.BILINEAR)
  404. return resized_img
  405. def get_last_usr_msg_idx(messages: List[Union[dict, Message]]) -> int:
  406. i = len(messages) - 1
  407. while (i >= 0) and (messages[i]['role'] != 'user'):
  408. i -= 1
  409. assert i >= 0, messages
  410. assert messages[i]['role'] == 'user'
  411. return i