1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- import re
- from typing import Iterator, List
- import json5
- from qwen_agent import Agent
- from qwen_agent.agents.assistant import Assistant
- from qwen_agent.agents.writing import ExpandWriting, OutlineWriting
- from qwen_agent.llm.schema import ASSISTANT, CONTENT, USER, Message
# Fixed three-step plan kept as JSON text: it is shown to the user verbatim
# and then parsed, with the steps executed in sorted-key order.
default_plan = '{"action1": "summarize", "action2": "outline", "action3": "expand"}'
# A leading run of upper-case Roman-numeral letters that is NOT followed by
# another ASCII letter. The lookahead keeps ordinary words that merely start
# with one of these letters ("Introduction", "Conclusion", "Methods") from
# being mistaken for numbered headings.
_ROMAN_PREFIX = re.compile(r'^[IVXLCDM]+(?![A-Za-z])')


def is_roman_numeral(s: str) -> bool:
    """Return True when *s* starts with an upper-case Roman-numeral label.

    The label must sit at the very beginning of the string (no leading
    whitespace) and must be delimited: the run of ``I V X L C D M`` letters
    has to be followed by a non-letter (e.g. ``.``, a space, CJK text) or by
    the end of the string. Only the letter set is checked; the numeral is
    not validated for well-formedness, so ``"IIII."`` is accepted.

    Args:
        s: A single line of text, typically one line of a generated outline.

    Returns:
        True if the line begins with a delimited Roman-numeral label,
        False otherwise (including for the empty string).
    """
    return _ROMAN_PREFIX.match(s) is not None
class WriteFromScratch(Agent):
    """Agent that writes a document by running a fixed summarize / outline / expand plan.

    Each stage is delegated to a dedicated sub-agent (``Assistant``,
    ``OutlineWriting`` and ``ExpandWriting`` respectively), all sharing this
    agent's ``llm``.
    """

    def _run(self, messages: List[Message], knowledge: str = '', lang: str = 'en') -> Iterator[List[Message]]:
        """Execute the default plan and stream the growing list of messages.

        Args:
            messages: The user's writing request; forwarded to the outline and
                expand sub-agents.
            knowledge: Reference material; forwarded to the summarize and
                expand sub-agents.
            lang: ``'zh'`` or ``'en'``. It selects the summarize prompt and is
                forwarded to every sub-agent.

        Yields:
            The accumulated message list. While a sub-agent is streaming, each
            yield is a new list made of the accumulated messages plus that
            sub-agent's current chunk.

        Raises:
            NotImplementedError: In the summarize step, if ``lang`` is neither
                ``'zh'`` nor ``'en'``.
        """
        response = [Message(ASSISTANT, f'>\n> Use Default plans: \n{default_plan}')]
        yield response

        def relay(stream):
            # Forward every chunk of a sub-agent stream on top of the messages
            # gathered so far; once the stream ends, commit the final
            # (non-empty) chunk to `response` and hand it back to the caller.
            final = None
            for final in stream:
                yield response + final
            if final:
                response.extend(final)
            return final

        plans = json5.loads(default_plan)
        summary_text = ''
        outline_text = ''

        for key in sorted(plans):
            step = plans[key]

            if step == 'summarize':
                response.append(Message(ASSISTANT, '>\n> Summarize Browse Content: \n'))
                yield response

                if lang not in ('zh', 'en'):
                    raise NotImplementedError
                if lang == 'zh':
                    user_request = '总结参考资料的主要内容'
                else:
                    user_request = 'Summarize the main content of reference materials.'

                summarizer = Assistant(llm=self.llm)
                final = yield from relay(
                    summarizer.run(messages=[Message(USER, user_request)], knowledge=knowledge, lang=lang))
                if final:
                    summary_text = final[-1][CONTENT]

            elif step == 'outline':
                response.append(Message(ASSISTANT, '>\n> Generate Outline: \n'))
                yield response

                # The outline is generated from the summary, not the raw knowledge.
                outliner = OutlineWriting(llm=self.llm)
                final = yield from relay(outliner.run(messages=messages, knowledge=summary_text, lang=lang))
                if final:
                    outline_text = final[-1][CONTENT]

            elif step == 'expand':
                response.append(Message(ASSISTANT, '>\n> Writing Text: \n'))
                yield response

                # Only outline lines accepted by is_roman_numeral are expanded.
                headings = [line for line in outline_text.split('\n') if is_roman_numeral(line)]
                total = len(headings)

                for position, heading in enumerate(headings, start=1):
                    response.append(Message(ASSISTANT, '>\n# '))
                    yield response

                    # `position` is the 1-based section number, which is also
                    # the 0-based index of the following heading (if any).
                    upcoming = headings[position].strip() if position < total else ''

                    expander = ExpandWriting(llm=self.llm)
                    yield from relay(
                        expander.run(
                            messages=messages,
                            knowledge=knowledge,
                            outline=outline_text,
                            index=str(position),
                            capture=heading.strip(),
                            capture_later=upcoming,
                            lang=lang,
                        ))

            # Any other plan value is silently ignored.
|