write_from_scratch.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import re
  2. from typing import Iterator, List
  3. import json5
  4. from qwen_agent import Agent
  5. from qwen_agent.agents.assistant import Assistant
  6. from qwen_agent.agents.writing import ExpandWriting, OutlineWriting
  7. from qwen_agent.llm.schema import ASSISTANT, CONTENT, USER, Message
  # Fixed three-step writing plan as a JSON5 object string.
  # WriteFromScratch._run parses it and executes the step names in
  # sorted-key order (action1 -> action2 -> action3).
  default_plan = '{"action1": "summarize", "action2": "outline", "action3": "expand"}'
  # A leading run of uppercase Roman-numeral letters that ends at a word
  # boundary ("II.", "IV ", "X)", end of string), so a word that merely
  # begins with one of these letters ("Introduction") does not match.
  # re.match already anchors at the start of the string.
  _ROMAN_HEADING_RE = re.compile(r'[IVXLCDM]+\b')


  def is_roman_numeral(s: str) -> bool:
      """Return True if *s* starts with a Roman-numeral section label.

      Used to pick top-level headings such as ``"II. Methods"`` out of a
      generated outline. The previous pattern accepted any line beginning
      with I/V/X/L/C/D/M, so ordinary headings like ``"Introduction"`` or
      ``"Conclusion"`` were misread as numbered sections. The numeral run
      must now be followed by a non-word character or the end of the line.

      Leading whitespace is not skipped, so indented sub-items never match.
      Single letters that are themselves valid numerals (``"C."``, ``"D."``)
      are still accepted; they cannot be told apart from Roman 100/500 here.
      NOTE(review): assumes OutlineWriting labels top-level sections with
      uppercase Roman numerals at column 0 -- confirm against its prompt.

      Args:
          s: One line of outline text.

      Returns:
          Whether the line begins with a Roman-numeral label.
      """
      return _ROMAN_HEADING_RE.match(s) is not None
  class WriteFromScratch(Agent):
      """Draft a long-form text from reference material, one plan step at a time.

      The steps come from the module-level ``default_plan`` JSON5 string and
      are executed in sorted-key order:

      * ``summarize``: an ``Assistant`` summarizes ``knowledge``;
      * ``outline``: ``OutlineWriting`` drafts an outline, using that summary
        as its knowledge;
      * ``expand``: ``ExpandWriting`` runs once for every outline line that
        ``is_roman_numeral`` accepts.

      Unknown step names are ignored. Progress is streamed: every yield is the
      list of assistant messages accumulated so far.
      """

      def _run(self, messages: List[Message], knowledge: str = '', lang: str = 'en') -> Iterator[List[Message]]:
          """Execute the plan and stream the growing transcript.

          Args:
              messages: Conversation forwarded to the outline and expand agents.
              knowledge: Reference text to summarize and to ground each section.
              lang: Language code; the summarize step supports ``'zh'`` and ``'en'``.

          Yields:
              The running list of assistant messages. While a sub-agent is
              streaming, its in-progress chunk is appended to a copy of that list.

          Raises:
              NotImplementedError: when the summarize step runs with any other ``lang``.
          """
          transcript = [Message(ASSISTANT, f'>\n> Use Default plans: \n{default_plan}')]
          yield transcript

          steps = json5.loads(default_plan)
          summary_text = ''
          outline_text = ''

          for step_key in sorted(steps):
              step = steps[step_key]

              if step == 'summarize':
                  transcript.append(Message(ASSISTANT, '>\n> Summarize Browse Content: \n'))
                  yield transcript
                  request_by_lang = {
                      'zh': '总结参考资料的主要内容',
                      'en': 'Summarize the main content of reference materials.',
                  }
                  if lang not in request_by_lang:
                      raise NotImplementedError
                  summarizer = Assistant(llm=self.llm)
                  final = yield from self._relay_stream(
                      transcript,
                      summarizer.run(
                          messages=[Message(USER, request_by_lang[lang])],
                          knowledge=knowledge,
                          lang=lang,
                      ),
                  )
                  if final:
                      summary_text = final[-1][CONTENT]

              elif step == 'outline':
                  transcript.append(Message(ASSISTANT, '>\n> Generate Outline: \n'))
                  yield transcript
                  outliner = OutlineWriting(llm=self.llm)
                  final = yield from self._relay_stream(
                      transcript,
                      outliner.run(messages=messages, knowledge=summary_text, lang=lang),
                  )
                  if final:
                      outline_text = final[-1][CONTENT]

              elif step == 'expand':
                  transcript.append(Message(ASSISTANT, '>\n> Writing Text: \n'))
                  yield transcript
                  # Top-level sections are the outline lines with a Roman-numeral label.
                  headings = [row for row in outline_text.split('\n') if is_roman_numeral(row)]
                  for position, heading in enumerate(headings, start=1):
                      transcript.append(Message(ASSISTANT, '>\n# '))
                      yield transcript
                      # The next section's heading, or '' for the last section.
                      following = headings[position].strip() if position < len(headings) else ''
                      writer = ExpandWriting(llm=self.llm)
                      yield from self._relay_stream(
                          transcript,
                          writer.run(
                              messages=messages,
                              knowledge=knowledge,
                              outline=outline_text,
                              index=str(position),
                              capture=heading.strip(),
                              capture_later=following,
                              lang=lang,
                          ),
                      )

              # Any other step name is deliberately a no-op.

      @staticmethod
      def _relay_stream(transcript: List[Message], stream: Iterator[List[Message]]) -> Iterator[List[Message]]:
          """Yield ``transcript + chunk`` per streamed chunk, then commit the last chunk.

          ``transcript`` is extended in place only when the final chunk is
          non-empty. The final chunk (``None`` for an empty stream) is the
          generator's return value, so callers can read it via ``yield from``.
          """
          latest = None
          for latest in stream:
              yield transcript + latest
          if latest:
              transcript.extend(latest)
          return latest