dialogue_retrieval_agent.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import datetime
  2. import os
  3. from typing import Iterator, List
  4. from qwen_agent.agents.assistant import Assistant
  5. from qwen_agent.llm.schema import SYSTEM, USER, ContentItem, Message
  6. from qwen_agent.settings import DEFAULT_WORKSPACE
  7. from qwen_agent.utils.utils import extract_text_from_message, save_text_to_file
  # User messages up to this many characters are used verbatim as the query;
  # longer ones go through the LLM-based extraction prompts below.
  MAX_TRUNCATED_QUERY_LENGTH = 1000

  # Chinese prompt asking the LLM to pull the user's request out of a long text.
  # `{ref_doc}` is filled in with the (possibly truncated) user message.
  EXTRACT_QUERY_TEMPLATE_ZH = (
      '<给定文本>\n'
      '{ref_doc}\n'
      '上面的文本是包括一段材料和一个用户请求,这个请求一般在最开头或最末尾,请你帮我提取出那个请求,你不需要回答这个请求,只需要打印出用户的请求即可!'
  )

  # English counterpart of the prompt above.
  EXTRACT_QUERY_TEMPLATE_EN = (
      '<Given Text>\n'
      '{ref_doc}\n'
      'The text above includes a section of reference material and a user request. '
      'The user request is typically located either at the beginning or the end. '
      'Please extract that request for me. '
      "You do not need to answer the request, just print out the user's request!"
  )

  # Prompt lookup keyed by language code.
  EXTRACT_QUERY_TEMPLATE = dict(zh=EXTRACT_QUERY_TEMPLATE_ZH, en=EXTRACT_QUERY_TEMPLATE_EN)
  # TODO: merge this agent into the retrieval tool
  class DialogueRetrievalAgent(Assistant):
      """Agent for answering questions over a very long dialogue.

      The dialogue history is written to a text file and passed to the parent
      ``Assistant`` as an attached file, together with a short query taken
      from the newest user message.
      """

      def _run(self,
               messages: List[Message],
               lang: str = 'en',
               session_id: str = '',
               **kwargs) -> Iterator[List[Message]]:
          """Store the dialogue in a file and answer the latest request from it.

          System messages are forwarded unchanged. Every other earlier message
          is flattened to ``"<role>: <text>"`` and written to a history file in
          ``DEFAULT_WORKSPACE``. The newest user message is reduced to a query
          (extracted by the LLM when the message is long) and sent to the
          parent agent along with the history file.

          Args:
              messages: Conversation so far; the last message must come from
                  the user.
              lang: Language code used to pick the query-extraction prompt and
                  forwarded to the parent agent. Codes without a prompt fall
                  back to the English one.
              session_id: Identifier embedded in the history file name.
              **kwargs: Forwarded unchanged to the parent ``_run``.

          Returns:
              Whatever the parent class's ``_run`` returns for the rewritten
              message list.

          Raises:
              AssertionError: If ``messages`` is empty or does not end with a
                  user message.
          """
          assert messages and messages[-1].role == USER

          new_messages = []
          content = []
          for msg in messages[:-1]:
              if msg.role == SYSTEM:
                  new_messages.append(msg)
              else:
                  content.append(f'{msg.role}: {extract_text_from_message(msg, add_upload_info=True)}')

          # Process the newest user message
          text = extract_text_from_message(messages[-1], add_upload_info=False)
          if len(text) <= MAX_TRUNCATED_QUERY_LENGTH:
              # Short enough to be used as the query directly.
              query = text
          else:
              if len(text) <= MAX_TRUNCATED_QUERY_LENGTH * 2:
                  latent_query = text
              else:
                  # The extraction prompt says the request is usually at the
                  # beginning or the end, so only head and tail are kept.
                  latent_query = f'{text[:MAX_TRUNCATED_QUERY_LENGTH]} ... {text[-MAX_TRUNCATED_QUERY_LENGTH:]}'

              # Fall back to the English prompt for an unknown language code
              # instead of failing with KeyError.
              template = EXTRACT_QUERY_TEMPLATE.get(lang, EXTRACT_QUERY_TEMPLATE['en'])
              # Only the final item of the LLM output iterator is needed.
              *_, last = self._call_llm(messages=[Message(role=USER, content=template.format(ref_doc=latent_query))])
              query = last[-1].content
              # A little tricky: If the extracted query is different from the original query, it cannot be removed
              text = text.replace(query, '')
              content.append(text)

          # Save content as file: This file is related to the session and the time.
          # str(datetime.now()) contains spaces and colons, and colons are not
          # allowed in Windows file names, so a filesystem-safe stamp is used.
          # Microseconds are kept so that quick successive calls do not collide.
          timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f')
          file_path = os.path.join(DEFAULT_WORKSPACE, f'dialogue_history_{session_id}_{timestamp}.txt')
          save_text_to_file(file_path, '\n'.join(content))

          new_content = [ContentItem(text=query), ContentItem(file=file_path)]
          if isinstance(messages[-1].content, list):
              # Keep the files and images attached to the newest user message.
              new_content.extend(item for item in messages[-1].content if item.file or item.image)
          new_messages.append(Message(role=USER, content=new_content))

          return super()._run(messages=new_messages, lang=lang, **kwargs)