# tools.py

import asyncpg
from qwen_agent.config import db_config
from qwen_agent.tools.image_gen import image_gen  # NOQA
import json
from langchain.utilities import SQLDatabase
import pymysql
from elasticsearch import AsyncElasticsearch, Elasticsearch
import traceback
from qwen_agent.tools.query_understanding import gen_company_search_dsl
from datetime import datetime
from qwen_agent.tools.code_interpreter import code_interpreter
import re
import asyncio
import aiomysql
import tritonclient.http as httpclient
import numpy as np
from qwen_agent.utils.util import toJson

pymysql.install_as_MySQLdb()


def simplify_chart(option):
    result = {'series': [], 'xAxis': [], 'yAxis': [], 'title': [], 'legend': []}
    json_data = json.loads(option)
    for key_level1 in ['series', 'xAxis', 'yAxis', 'title', 'legend']:
        for item in json_data.get(key_level1, []):
            tmp = {}
            for key, value in item.items():
                if key in ['type', 'data', 'text', 'name', 'label']:
                    tmp[key] = value
            result[key_level1].append(tmp)
    # Handle multiple series (line or bar charts): keep only the first series selected in the legend
    if len(result['series']) > 1 and result['series'][0]['type'] in ('line', 'bar'):
        i = 0
        selected = {}
        for series in result['series']:
            selected[series['name']] = bool(i == 0)
            i += 1
        result['legend'] = dict(selected=selected)
    return json.dumps(result, ensure_ascii=False)
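
# A hypothetical illustration of simplify_chart: only the whitelisted keys survive.
#   opt = '{"series": [{"type": "bar", "data": [1, 2], "color": "#f00"}], "xAxis": [{"name": "月份", "min": 0}]}'
#   simplify_chart(opt)
#   -> '{"series": [{"type": "bar", "data": [1, 2]}], "xAxis": [{"name": "月份"}], "yAxis": [], "title": [], "legend": []}'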


def format_json(data):
    # Walk every key/value pair in the dict
    for key in data:
        if '金额' in key or '预算' in key:
            # Check whether every element can be converted to a float
            try:
                values = [float(value) if value != 'None' else 0 for value in data[key]]
                # If so, rescale each element to a readable unit
                max_value = max(values)
                if max_value >= 1e8:
                    data[key] = [f"{value / 1e8:.2f}亿" for value in values]
                elif max_value >= 1e4:
                    data[key] = [f"{value / 1e4:.2f}万" for value in values]
                else:
                    data[key] = [f"{value}" for value in values]
            except ValueError:
                # If not, keep the original values
                pass
    return data
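
# A small illustration (hypothetical data): columns whose header contains '金额' or '预算'
# are rescaled based on the largest value in that column.
#   format_json({'中标金额': ['120000000', '35000000'], '发布月份': ['1', '2']})
#   -> {'中标金额': ['1.20亿', '0.35亿'], '发布月份': ['1', '2']}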


class MySQLSearcher():

    def __init__(self, db_name='lianqiai_db') -> None:
        # Open the database connection
        self.db_name = db_name
        self.connect()
        # print(f'db name: {self.db_name}')

    def connect(self):
        self.db = pymysql.connect(
            host='xx.aliyuncs.com',
            user='xxx',
            password='xxx',
            db=self.db_name,
            port=3306
        )
        self.cursor = self.db.cursor()

    def format_result(self):
        result = self.cursor.fetchmany(20)
        if len(result) == 0:
            return ''
        headers = [column[0] for column in self.cursor.description]
        # print(f"headers:{headers}")
        # Build a Markdown-format table with tabulate
        # markdown_table = tabulate(result, headers, tablefmt="pipe")
        json_data = {}
        for i, header in enumerate(headers):
            json_data[header] = [f"{row[i]}" if len(f"{row[i]}") > 0 else '未知' for row in result]
        json_data = format_json(json_data)
        return json.dumps(json_data, ensure_ascii=False)

    def run(self, command):
        try:
            self.cursor.execute(command)
            result = self.format_result()
        except Exception as e:
            # Reconnect once and retry before giving up
            try:
                self.connect()
                self.cursor = self.db.cursor()
                self.cursor.execute(command)
                result = self.format_result()
            except Exception as ex:
                return f"ERROR:{traceback.format_exc()[-200:]}", False
        return result, True

    def _close(self):
        self.cursor.close()
        self.db.commit()
        self.db.close()
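
# A minimal usage sketch for MySQLSearcher (assumes the placeholder host/user/password
# above are replaced with real credentials; the table and columns are hypothetical):
#   searcher = MySQLSearcher('lianqiai_db')
#   rows_json, ok = searcher.run("SELECT org_name, budget FROM some_table LIMIT 5")
#   searcher._close()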


class AsyncMySQLSearcher():

    def __init__(self, db_name='lianqiai_db') -> None:
        # The database connection is created lazily via connect()/register()
        self.db_name = db_name
        self.db = None
        self.pool = None
        # await self.connect()
        # print(f'db name: {self.db_name}')

    async def connect(self):
        self.db = await aiomysql.connect(
            # host='rm-bp13i5ci7o9ev1241ho.mysql.rds.aliyuncs.com',
            host='xxxx.ads.aliyuncs.com',  # analysisDB
            user='xxxx',
            password='xxxxx',
            db=self.db_name,
            port=3306
        )
        # self.db = await aiomysql.connect(
        #     host='10.10.0.10',
        #     user='root',
        #     password='Lianqiai',
        #     db='lianqi_db',
        #     port=13306
        # )
        self.cursor = await self.db.cursor()

    async def register(self):
        '''
        Initialization: create the database connection pool.
        :return:
        '''
        try:
            print("start to connect db!")
            self.pool = await aiomysql.create_pool(host='amv-bp1sk343446b8u0d100001808o.ads.aliyuncs.com', port=3306,
                                                   user='lianqi_admin', password='(lianqi666666)',
                                                   db='lianqiai_db')
            print("succeed to connect db!")
        except asyncio.CancelledError:
            raise asyncio.CancelledError
        except Exception as ex:
            print("MySQL connection failed: {}".format(ex.args[0]))

    async def run(self, sql):
        if not self.pool:
            await self.register()
        '''
        Query: the usual flow is to acquire a connection and cursor, fetch the data, then release the connection.
        :param pool:
        :return:
        '''
        # conn, cur = await self.getCurosr()
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(sql)
                    result = await cur.fetchmany(30)
                    return await self.format_result(result, cur), True
        except Exception as e:
            return f"ERROR:{traceback.format_exc()[-200:]}", False

    async def format_result(self, result, cursor):
        if len(result) == 0:
            return ''
        headers = [column[0] for column in cursor.description]
        json_data = {}
        for i, header in enumerate(headers):
            json_data[header] = [f"{row[i]}" if len(f"{row[i]}") > 0 else '未知' for row in result]
        json_data = format_json(json_data)
        # return json.dumps(json_data, ensure_ascii=False)
        row_result = []
        # Convert the column-oriented data into a list of row dicts
        for i in range(len(json_data[headers[0]])):
            tmp_dict = {}
            for head in headers:
                tmp_dict[head] = json_data[head][i]
            row_result.append(tmp_dict)
        return json.dumps(row_result, ensure_ascii=False)

    def _close(self):
        self.cursor.close()
        self.db.commit()
        self.db.close()


class AsyncPGSearcher:

    def __init__(self, db_name='pg') -> None:
        # The database connection is created lazily via connect()/register()
        self.db_name = db_name
        self.db = None
        self.pool = None
        # await self.connect()
        # print(f'db name: {self.db_name}')

    async def connect(self):
        pg_config = db_config.db_list.get(self.db_name)
        print("pg_config: {}".format(pg_config))
        self.db = await asyncpg.connect(
            host=pg_config.get("host"),  # analysisDB
            user=pg_config.get("user"),
            password=pg_config.get("password"),
            database=pg_config.get("database"),
            port=pg_config.get("port")
        )
        # asyncpg has no aiomysql-style standalone cursor; queries go through the pool in run()

    async def register(self):
        '''
        Initialization: create the database connection pool.
        :return:
        '''
        try:
            print("start to connect db!")
            pg_config = db_config.db_list.get(self.db_name)
            self.pool = await asyncpg.create_pool(host=pg_config.get("host"),
                                                  user=pg_config.get("user"),
                                                  password=pg_config.get("password"),
                                                  database=pg_config.get("database"),
                                                  port=pg_config.get("port"))
            print("succeed to connect db!")
        except asyncio.CancelledError:
            raise asyncio.CancelledError
        except Exception as ex:
            print("PostgreSQL connection failed: {}".format(ex.args[0]))

    async def run(self, sql):
        if not self.pool:
            await self.register()
        '''
        Query: fetch the rows through the pool, which releases the connection afterwards.
        :param pool:
        :return:
        '''
        # conn, cur = await self.getCurosr()
        try:
            result = await self.pool.fetch(sql)
            return await self.format_result(result), True
        except Exception as e:
            # return f"ERROR:{json.dumps(e, ensure_ascii=False)}", False
            return f"ERROR:{traceback.format_exc()[-200:]}", False

    async def format_result(self, result):
        if len(result) == 0:
            return ''
        result_json_obj = []
        for record in result:
            record_dict = dict(record)
            result_json_obj.append(record_dict)
        return toJson(result_json_obj)

    async def _close(self):
        if self.db is not None:
            await self.db.close()
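
# A minimal usage sketch for AsyncPGSearcher (assumes db_config.db_list has a valid 'pg'
# entry; the table comes from modify_sql below and the column names are illustrative):
#   async def demo():
#       searcher = AsyncPGSearcher('pg')
#       rows_json, ok = await searcher.run("SELECT 招标单位, 中标金额 FROM agent_bidding_history_detail_all LIMIT 5")
#       print(ok, rows_json)
#   asyncio.run(demo())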


def format_json(data):
    # Walk every key/value pair in the dict
    for key in data:
        if '金额' in key or '预算' in key:
            # Check whether every element can be converted to a float
            try:
                values = [float(value) if value != 'None' else 0 for value in data[key]]
                # If so, rescale each element to a readable unit
                max_value = max(values)
                if max_value >= 1e8:
                    data[key] = [f"{value / 1e8:.2f}亿" for value in values]
                elif max_value >= 1e4:
                    data[key] = [f"{value / 1e4:.2f}万" for value in values]
                else:
                    data[key] = [f"{value}" for value in values]
            except ValueError:
                # If not, keep the original values
                pass
    return data


# db = MySQLSearcher()
# es = AsyncElasticsearch([{'host': 'xxxxx.public.elasticsearch.aliyuncs.com', 'port': 9200, 'scheme': "http"}],
#                         http_auth=('xxx', 'xxxx!'), timeout=3600)
# async_db = AsyncMySQLSearcher()
# asyncio.run(async_db.register())
async_db = AsyncPGSearcher('pg')
async_xzdb = AsyncPGSearcher('xzpg')


def modify_sql(sql_query):
    # Use a regex to find GROUP BY and everything after it, up to HAVING, ORDER BY, or the end of the string
    match = re.search(r'(GROUP BY|group by)(.*?)(order by|having|HAVING|ORDER BY|;|$)', sql_query, re.IGNORECASE)
    target_table = "agent_bidding_history_detail_all"
    # If a match is found
    if match:
        # Extract the fields after GROUP BY, strip extra whitespace, and split them into a list
        fields = match.group(2).strip().split(',')
        for field in fields:
            if field.strip() == '招标单位':
                target_table = 'agent_bidding_history_detail_by_ifb_new'
                break
            elif field.strip() == '中标单位':
                target_table = 'agent_bidding_history_detail_by_wtb_new'
                break
            elif field.strip() == '招标产品':
                target_table = 'agent_bidding_history_detail_by_prod_new'
                break
    sql_query = sql_query.replace('agent_bidding_history_detail_all', target_table)
    return sql_query
    # return [field.strip() for field in fields]  # strip whitespace around each field
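
# Illustration: a query grouped by 招标单位 is rerouted to the table variant keyed by that field.
#   modify_sql("SELECT 招标单位, COUNT(*) FROM agent_bidding_history_detail_all GROUP BY 招标单位")
#   -> "SELECT 招标单位, COUNT(*) FROM agent_bidding_history_detail_by_ifb_new GROUP BY 招标单位"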


async def call_plugin(plugin_name: str, plugin_args: str):
    # plugin_args = plugin_args.replace("```json", '')
    # plugin_args = plugin_args.replace("```", '').rstrip("`\n")
    # if plugin_args.startswith('```'):
    #     plugin_args = extract_json(plugin_args)
    # if plugin_args.startswith('json'):
    #     plugin_args = plugin_args.replace("json", '')
    import time
    print("plugin_args", plugin_args)
    print("plugin_name", plugin_name)
    # Strip a surrounding ```...``` fence from the arguments if the model emitted one
    if plugin_args.startswith('```'):
        triple_match = re.search(r'```[^\n]*\n(.+?)```', plugin_args, re.DOTALL)
        if triple_match:
            plugin_args = triple_match.group(1)
            print("plugin_args_clean:", plugin_args)
    if plugin_name == "TenderResultSqlAgent":
        try:
            # try:
            #     if plugin_args.startswith('```'):
            #         triple_match = re.search(r'```[^\n]*\n(.+?)```', plugin_args, re.DOTALL)
            #         if triple_match:
            #             plugin_args = triple_match.group(1)
            #             sql_to_execute = json.loads(plugin_args)['sql_code']
            #     else:
            #         sql_to_execute = json.loads(plugin_args)['sql_code']
            sql_to_execute = json.loads(plugin_args)['sql_code']
            sql_to_execute = modify_sql(sql_to_execute)
            # except:
            #     sql_to_execute = json.loads(plugin_args)['description']
            # sql_to_execute = sql_to_execute.replace(';', '')
            # if plugin_args.startswith('sql'):
            #     sql_to_execute = plugin_args.rstrip('sql')
            # else:
            #     sql_to_execute = plugin_args
            print(f"sql_to_execute:{sql_to_execute}")
            print("plugin_name", plugin_name)
            a = time.time()
            res_tuples = await async_db.run(sql_to_execute)
            print('SQL Time Cost:', time.time() - a)
            result, success = res_tuples
            if success and len(result) == 0:
                return '暂未查询到相关信息,可能是字段错误,请尝试调整查询条件,例如where条件中把“招标单位”改为“中标单位”', plugin_args, False
            return f"```json\n{result}\n```", plugin_args, success
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'search_ES':
        print(f"es plugin args: {plugin_args}")
        if plugin_args.startswith('```'):
            plugin_args = re.sub(r'(^```(json)?\n|\n```$)', '', plugin_args.strip())
            print("plugin_args", plugin_args)
        try:
            plugin_args = json.loads(plugin_args)
            if isinstance(plugin_args, list):
                plugin_args = plugin_args[0]
            org_query = ''
            if 'organization_name' in plugin_args:
                org_query = plugin_args['organization_name']
            elif "name" in plugin_args and plugin_args['name'] == "organization_name":
                org_query = plugin_args['value']
            body, _ = gen_company_search_dsl(org_query)
            # print(f'es to execute: {body}')
            result = await es.search(body=body, index='company_basic_info_index_new')
            # print(f'es result: {result}')
            names = []
            if result["hits"]['hits'][0]['_source']['companyname'] == org_query:
                return json.dumps({"ExactMatch": org_query}, ensure_ascii=False), plugin_args, True
            for h in result["hits"]['hits']:
                names.append(h['_source']['companyname'])
            result = json.dumps({"FuzzyMatch": names}, ensure_ascii=False)
            return f"```json\n{result}\n```", plugin_args, True
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'generate_table':
        try:
            # plot_args = json.loads(plugin_args.replace("\'", "\""))
            # plot_args['series']['data'] = sorted(plot_args['series']['data'], key=lambda x: x[0])
            return plugin_args.replace("\'", "\""), plugin_args, True
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'generate_chart':
        try:
            print(f'code args: {plugin_args}')
            code_result = await code_interpreter(plugin_args)
            triple_match = re.search(r'```[^\n]*\n(.+?)```', code_result, re.DOTALL)
            # print(f'code results: {plugin_args}')
            if triple_match:
                text = triple_match.group(1)
                if code_result.startswith("error:"):
                    return f"\n{text}", plugin_args, False
                else:
                    print('code interpreter succeeded!', simplify_chart(text))
                    return f"\n{simplify_chart(text)}\n", plugin_args, True
            return code_result, plugin_args, False
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'recall_words':
        try:
            client = httpclient.InferenceServerClient(url="10.10.0.11:20100", verbose=False, network_timeout=3000)
            input_tensors = [httpclient.InferInput("source", [1, 1], "BYTES")]  # httpclient.InferInput("product_list", [1, 1], "BYTES"),
            source = np.array([['bid_prods'.encode('utf-8')]], dtype=np.object_)
            input_tensors[0].set_data_from_numpy(source)
            words = json.loads(plugin_args)['mention']
            if isinstance(words, str):
                text = np.array([[words.encode('utf-8')]], dtype=np.object_)
                input_tensors.append(httpclient.InferInput("product_list", [1, 1], "BYTES"))
                input_tensors[1].set_data_from_numpy(text)
                words = [words]
            elif isinstance(words, list):
                text = np.array([[t.encode('utf-8')] for t in words], dtype=np.object_)
                input_tensors.append(httpclient.InferInput("product_list", [len(words), 1], "BYTES"))
                input_tensors[1].set_data_from_numpy(text)
            else:
                return 'plugin_args ERROR', plugin_args, False
            outputs = [
                httpclient.InferRequestedOutput("output_products"),
            ]
            results = client.infer(model_name="product_sim_recall", inputs=input_tensors, outputs=outputs)
            output_data = results.as_numpy("output_products")
            return '|'.join(list(set([data.decode('utf-8') for data in output_data] + words))), plugin_args, True
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    else:
        raise NotImplementedError
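
# A minimal sketch of how call_plugin is typically awaited (illustrative only: the SQL and
# argument payload are made up, and the call needs the database/ES/Triton services configured above):
#   async def demo_call():
#       args = json.dumps({"sql_code": "SELECT 招标单位, COUNT(*) AS cnt FROM agent_bidding_history_detail_all GROUP BY 招标单位 LIMIT 10"},
#                         ensure_ascii=False)
#       observation, used_args, ok = await call_plugin("TenderResultSqlAgent", args)
#       print(ok, observation)
#   asyncio.run(demo_call())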


tools_list = [
    {
        'name_for_human': '查询招投标数据库',
        'name_for_model': 'TenderResultSqlAgent',
        'description_for_model': """
当需要连接MySQL数据库并执行一段sql时,请使用此功能。
"""
        + ' Format the arguments as a JSON object.',
        'parameters': [{'name': 'sql_code', 'type': 'string',
                        'description': '合法的MySQL查询语言。不接受【select *】,必须使用【select xxx,yyy】'}]
    },
    {
        'name_for_human': '获取准确的公司或单位名称',
        'name_for_model': 'search_ES',
        'description_for_model': '当需要获取公司的具体名称时,请使用此API。输入需要确认的公司或单位名称,返回与之相似的公司或单位名称'
        + ' Format the arguments as a JSON object.',
        # 'parameters': [{'name': 'dsl_json', 'type': 'string', 'description': 'Json格式的Elasticsearch的DSL查询语言,可以match的字段有:【companyname】'}]
        'parameters': [{'name': 'organization_name', 'type': 'string', 'description': '需要确认的公司或单位名称'}]
    },
    {
        'name_for_human': '生成表格',
        'name_for_model': 'generate_table',
        'description_for_model': '[生成表格]用于将数据以表格方式呈现。输入HTML格式的表格数据,返回生成的表格url。'
        + ' Format the arguments as a JSON object.',
        'parameters': [
            {
                'name': 'HTML_table',
                'description': """HTML格式的表格数据,例如:<table><tr><th>xxx</th><th>nnnn</th></tr><tr><td>ssss</td><td>dddd</td></tr></table>""",
                'required': True,
                'schema': {'type': 'HTML'},
            }
        ]
    },
    {
        'name_for_human': '生成图表',
        'name_for_model': 'generate_chart',
        'description_for_model': """[生成折线图/柱状图/散点图/饼图]是一个图像生成服务,用于生成折线图/柱状图/散点图。输入需求和数据,返回生成的图表的python代码。""",
        'parameters': [{
            'name': 'code',
            'type': 'string',
            'description': '待执行的pyecharts代码。'
        }]
    },
    {
        'name_for_human': '语义召回',
        'name_for_model': 'recall_words',
        'description_for_model': """[语义召回]可以返回与输入文本相似的文本""",
        'parameters': [{
            'name': 'mention',
            'type': 'list of string',
            'description': "原始文本中需要召回的文本,通常是一个或多个词,构成一个list:['xxx','yyy']"
        }]
    }
]


if __name__ == "__main__":
    code = """import json
import pyecharts.options as opts
from pyecharts.globals import ThemeType
from pyecharts.charts import Line

# Data, given as a Dict
source = {
    '发布月份': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'],
    '招标次数': ['16', '19', '32', '20', '26', '14', '29', '22', '38', '21', '14', '6']
}

# Line chart
line = Line()
line.add_xaxis(source['发布月份'])
line.add_yaxis('招标次数', source['招标次数'])
line.set_global_opts(title_opts=opts.TitleOpts(title='上海市青浦区徐泾镇人民政府招标公告发布频率'),
                     xaxis_opts=opts.AxisOpts(name='月份'),
                     yaxis_opts=opts.AxisOpts(name='招标次数'))
option = line.dump_options()
print(option)
"""
    # code_interpreter is a coroutine (it is awaited in generate_chart above), so run it in an event loop
    print('-----\n', asyncio.run(code_interpreter(code)))
    # s = MySQLSearcher()
    # print(s.run("SELECT publishdate, COUNT(*) as bid_count FROM tmp_qwen_copilot_orgname_bidding_success WHERE bidding_org_name = '34部队' GROUP BY publishdate ORDER BY publishdate"))