tools.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. import asyncpg
  2. from qwen_agent.config import db_config
  3. from qwen_agent.tools.image_gen import image_gen # NOQA
  4. import json
  5. from langchain.utilities import SQLDatabase
  6. import pymysql
  7. from elasticsearch import AsyncElasticsearch, Elasticsearch
  8. import traceback
  9. from qwen_agent.tools.query_understanding import gen_company_search_dsl
  10. import pymysql
  11. from datetime import datetime
  12. from qwen_agent.tools.code_interpreter import code_interpreter
  13. import re
  14. import asyncio
  15. import aiomysql
  16. import tritonclient.http as httpclient
  17. import numpy as np
  18. from qwen_agent.utils.util import toJson
  19. pymysql.install_as_MySQLdb()
  20. def simplify_chart(option):
  21. result = {'series': [], 'xAxis': [], 'yAxis': [], 'title': [], 'legend': []}
  22. json_data = json.loads(option)
  23. for key_level1 in ['series', 'xAxis', 'yAxis', 'title', 'legend']:
  24. for item in json_data.get(key_level1, []):
  25. tmp = {}
  26. for key, value in item.items():
  27. if key in ['type', 'data', 'text', 'name', 'label']:
  28. tmp[key] = value
  29. result[key_level1].append(tmp)
  30. # 对多个series 的进行处理 (折线图或者柱状图的情况下)
  31. if len(result['series']) > 1 and result['series'][0]['type'] in ('line', 'bar'):
  32. i = 0
  33. selected = {}
  34. for series in result['series']:
  35. selected[series['name']] = bool(i == 0)
  36. i += 1
  37. result['legend'] = dict(selected=selected)
  38. return json.dumps(result, ensure_ascii=False)
  39. def format_json(data):
  40. # 遍历字典的每一个键值对
  41. for key in data:
  42. if '金额' in key or '预算' in key:
  43. # 检查是否所有元素都可以转换为浮点数
  44. try:
  45. values = [float(value) if value != 'None' else 0 for value in data[key]]
  46. # 如果可以,那么对每个元素进行单位转换
  47. max_value = max(values)
  48. if max_value >= 1e8:
  49. data[key] = [f"{value / 1e8:.2f}亿" for value in values]
  50. elif max_value >= 1e4:
  51. data[key] = [f"{value / 1e4:.2f}万" for value in values]
  52. else:
  53. data[key] = [f"{value}" for value in values]
  54. except ValueError:
  55. # 如果不能,那么保留原来的值
  56. pass
  57. return data
  58. class MySQLSearcher():
  59. def __init__(self, db_name='lianqiai_db') -> None:
  60. # 建立数据库连接
  61. self.db_name = db_name
  62. self.connect()
  63. # print(f'db name: {self.db_name}')
  64. def connect(self):
  65. self.db = pymysql.connect(
  66. host='xx.aliyuncs.com',
  67. user='xxx',
  68. password='xxx',
  69. db=self.db_name,
  70. port=3306
  71. )
  72. self.cursor = self.db.cursor()
  73. def format_result(self):
  74. result = self.cursor.fetchmany(20)
  75. if len(result) == 0:
  76. return ''
  77. headers = [column[0] for column in self.cursor.description]
  78. # print(f"headers:{headers}")
  79. # 使用tabulate创建Markdown格式的表格
  80. # markdown_table = tabulate(result, headers, tablefmt="pipe")
  81. json_data = {}
  82. for i, header in enumerate(headers):
  83. json_data[header] = [f"{row[i]}" if len(f"{row[i]}") > 0 else '未知' for row in result]
  84. json_data = format_json(json_data)
  85. return json.dumps(json_data, ensure_ascii=False)
  86. def run(self, command):
  87. try:
  88. self.cursor.execute(command)
  89. result = self.format_result()
  90. except Exception as e:
  91. try:
  92. self.connect()
  93. self.cursor = self.db.cursor()
  94. self.cursor.execute(command)
  95. result = self.format_result()
  96. except Exception as ex:
  97. return f"ERROR:{traceback.format_exc()[-200:]}", False
  98. return result, True
  99. def _close(self):
  100. self.cursor.close()
  101. self.db.commit()
  102. self.db.close()
  103. class AsyncMySQLSearcher():
  104. def __init__(self, db_name='lianqiai_db') -> None:
  105. # 建立数据库连接
  106. self.db_name = db_name
  107. self.db = None
  108. self.pool = None
  109. # await self.connect()
  110. # print(f'db name: {self.db_name}')
  111. async def connect(self):
  112. self.db = await aiomysql.connect(
  113. # host='rm-bp13i5ci7o9ev1241ho.mysql.rds.aliyuncs.com',
  114. host='xxxx.ads.aliyuncs.com', # analysisDB
  115. user='xxxx',
  116. password='xxxxx',
  117. db=self.db_name,
  118. port=3306
  119. )
  120. # self.db = await aiomysql.connect(
  121. # host='10.10.0.10',
  122. # user='root',
  123. # password='Lianqiai',
  124. # db='lianqi_db',
  125. # port=13306
  126. # )
  127. self.cursor = await self.db.cursor()
  128. async def register(self):
  129. '''
  130. 初始化,获取数据库连接池
  131. :return:
  132. '''
  133. try:
  134. print("start to connect db!")
  135. self.pool = await aiomysql.create_pool(host='amv-bp1sk343446b8u0d100001808o.ads.aliyuncs.com', port=3306,
  136. user='lianqi_admin', password='(lianqi666666)',
  137. db='lianqiai_db')
  138. print("succeed to connect db!")
  139. except asyncio.CancelledError:
  140. raise asyncio.CancelledError
  141. except Exception as ex:
  142. print("mysql数据库连接失败:{}".format(ex.args[0]))
  143. async def run(self, sql):
  144. if not self.pool:
  145. await self.register()
  146. '''
  147. 查询, 一般流程是首先获取连接,光标,获取数据之后,则需要释放连接
  148. :param pool:
  149. :return:
  150. '''
  151. # conn, cur = await self.getCurosr()
  152. try:
  153. async with self.pool.acquire() as conn:
  154. async with conn.cursor() as cur:
  155. await cur.execute(sql)
  156. result = await cur.fetchmany(30)
  157. return await self.format_result(result, cur), True
  158. except Exception as e:
  159. return f"ERROR:{traceback.format_exc()[-200:]}", False
  160. async def format_result(self, result, cursor):
  161. if len(result) == 0:
  162. return ''
  163. headers = [column[0] for column in cursor.description]
  164. json_data = {}
  165. for i, header in enumerate(headers):
  166. json_data[header] = [f"{row[i]}" if len(f"{row[i]}") > 0 else '未知' for row in result]
  167. json_data = format_json(json_data)
  168. # return json.dumps(json_data,ensure_ascii=False)
  169. row_result = []
  170. # 将数据转换为列表
  171. for i in range(len(json_data[headers[0]])):
  172. tmp_dict = {}
  173. for head in headers:
  174. tmp_dict[head] = json_data[head][i]
  175. row_result.append(tmp_dict)
  176. return json.dumps(row_result, ensure_ascii=False)
  177. def _close(self):
  178. self.cursor.close()
  179. self.db.commit()
  180. self.db.close()
  181. class AsyncPGSearcher:
  182. def __init__(self, db_name='pg') -> None:
  183. # 建立数据库连接
  184. self.db_name = db_name
  185. self.db = None
  186. self.pool = None
  187. # await self.connect()
  188. # print(f'db name: {self.db_name}')
  189. async def connect(self):
  190. pg_config = db_config.db_list.get(self.db_name)
  191. print("pg_config:" %pg_config);
  192. self.db = await asyncpg.connect(
  193. host=pg_config.get("host"), # analysisDB
  194. user=pg_config.get("user"),
  195. password=pg_config.get("password"),
  196. database=pg_config.get("database"),
  197. port=pg_config.get("port")
  198. )
  199. self.cursor = await self.db.cursor()
  200. async def register(self):
  201. '''
  202. 初始化,获取数据库连接池
  203. :return:
  204. '''
  205. try:
  206. print("start to connect db!")
  207. pg_config = db_config.db_list.get(self.db_name)
  208. self.pool = await asyncpg.create_pool(host=pg_config.get("host"),
  209. user=pg_config.get("user"),
  210. password=pg_config.get("password"),
  211. database=pg_config.get("database"),
  212. port=pg_config.get("port"))
  213. print("succeed to connect db!")
  214. except asyncio.CancelledError:
  215. raise asyncio.CancelledError
  216. except Exception as ex:
  217. print("pg数据库连接失败:{}".format(ex.args[0]))
  218. async def run(self, sql):
  219. if not self.pool:
  220. await self.register()
  221. '''
  222. 查询, 一般流程是首先获取连接,光标,获取数据之后,则需要释放连接
  223. :param pool:
  224. :return:
  225. '''
  226. # conn, cur = await self.getCurosr()
  227. try:
  228. result = await self.pool.fetch(sql)
  229. return await self.format_result(result), True
  230. except Exception as e:
  231. # return f"ERROR:{json.dumps(e, ensure_ascii=False)}", False
  232. return f"ERROR:{traceback.format_exc()[-200:]}", False
  233. async def format_result(self, result):
  234. if len(result) == 0:
  235. return ''
  236. result_json_obj = []
  237. for record in result:
  238. record_dict = dict(record)
  239. result_json_obj.append(record_dict)
  240. json_result = toJson(result_json_obj)
  241. return toJson(result_json_obj)
  242. def _close(self):
  243. self.cursor.close()
  244. self.db.commit()
  245. self.db.close()
  246. def format_json(data):
  247. # 遍历字典的每一个键值对
  248. for key in data:
  249. if '金额' in key or '预算' in key:
  250. # 检查是否所有元素都可以转换为浮点数
  251. try:
  252. values = [float(value) if value != 'None' else 0 for value in data[key]]
  253. # 如果可以,那么对每个元素进行单位转换
  254. max_value = max(values)
  255. if max_value >= 1e8:
  256. data[key] = [f"{value / 1e8:.2f}亿" for value in values]
  257. elif max_value >= 1e4:
  258. data[key] = [f"{value / 1e4:.2f}万" for value in values]
  259. else:
  260. data[key] = [f"{value}" for value in values]
  261. except ValueError:
  262. # 如果不能,那么保留原来的值
  263. pass
  264. return data
# Module-level clients used by call_plugin.
# db = MySQLSearcher()
# NOTE(review): `es` is referenced by the 'search_ES' branch of call_plugin,
# but its construction below is commented out, so that branch currently fails
# with NameError (returned as an "ERROR:" result). Load the host/credentials
# from configuration when re-enabling it.
# es = AsyncElasticsearch([{'host':'xxxxx.public.elasticsearch.aliyuncs.com', 'port': 9200,'scheme': "http"}],
# http_auth=('xxx','xxxx!'),timeout=3600)
# async_db = AsyncMySQLSearcher()
# asyncio.run(async_db.register())
# Shared PostgreSQL searcher used by the SQL plugin; its pool is created on first run().
async_db = AsyncPGSearcher('pg')
  271. def modify_sql(sql_query):
  272. # 使用正则表达式寻找 GROUP BY 和其后面的内容,直到遇到 HAVING, ORDER BY 或者字符串结束
  273. match = re.search(r'(GROUP BY|group by)(.*?)(order by|having|HAVING|ORDER BY|;|$)', sql_query, re.IGNORECASE)
  274. target_table = "agent_bidding_history_detail_all"
  275. # 如果找到匹配项
  276. if match:
  277. # 提取 GROUP BY 后面的字段,去掉多余的空格,并分割成列表
  278. fields = match.group(2).strip().split(',')
  279. for field in fields:
  280. if field.strip() == '招标单位':
  281. target_table = 'agent_bidding_history_detail_by_ifb_new'
  282. break
  283. elif field.strip() == '中标单位':
  284. target_table = 'agent_bidding_history_detail_by_wtb_new'
  285. break
  286. elif field.strip() == '招标产品':
  287. target_table = 'agent_bidding_history_detail_by_prod_new'
  288. break
  289. sql_query = sql_query.replace('agent_bidding_history_detail_all', target_table)
  290. return sql_query
  291. # return [field.strip() for field in fields] # 去掉每个字段前后的空格
  292. async def call_plugin(plugin_name: str, plugin_args: str):
  293. # plugin_args = plugin_args.replace("```json",'')
  294. # plugin_args = plugin_args.replace("```",'').rstrip("`\n")
  295. # if plugin_args.startswith('```'):
  296. # plugin_args=extract_json(plugin_args)
  297. # if plugin_args.startswith('json'):
  298. # plugin_args = plugin_args.replace("json",'')
  299. import time
  300. print("plugin_args", plugin_args)
  301. print("plugin_name", plugin_name)
  302. if plugin_args.startswith('```'):
  303. triple_match = re.search(r'```[^\n]*\n(.+?)```', plugin_args, re.DOTALL)
  304. if triple_match:
  305. plugin_args = triple_match.group(1)
  306. print("plugin_args_clean:", plugin_args)
  307. if plugin_name == "TenderResultSqlAgent":
  308. try:
  309. # try:
  310. # if plugin_args.startswith('```'):
  311. # triple_match = re.search(r'```[^\n]*\n(.+?)```', plugin_args, re.DOTALL)
  312. # if triple_match:
  313. # plugin_args = triple_match.group(1)
  314. # sql_to_execute = json.loads(plugin_args)['sql_code']
  315. # else:
  316. # sql_to_execute = json.loads(plugin_args)['sql_code']
  317. sql_to_execute = json.loads(plugin_args)['sql_code']
  318. sql_to_execute = modify_sql(sql_to_execute)
  319. # except:
  320. # sql_to_execute = json.loads(plugin_args)['description']
  321. # sql_to_execute = sql_to_execute.replace(';','')
  322. # if plugin_args.startswith('sql'):
  323. # sql_to_execute = plugin_args.rstrip('sql')
  324. # else:
  325. # sql_to_execute = plugin_args
  326. print(f"sql_to_execute:{sql_to_execute}")
  327. print("plugin_name", plugin_name)
  328. #
  329. a = time.time()
  330. res_tuples = await async_db.run(sql_to_execute)
  331. print('SQL Time Cost:', time.time() - a)
  332. result, success = res_tuples
  333. if success and len(result) == 0:
  334. return '暂未查询到相关信息,可能是字段错误,请尝试调整查询条件,例如where条件中把“招标单位”改为“中标单位”', plugin_args, False
  335. return f"```json\n{result}\n```", plugin_args, success
  336. except:
  337. return f"ERROR:{traceback.format_exc()}", plugin_args, False
  338. elif plugin_name == 'search_ES':
  339. print(f"es plugin args: {plugin_args}")
  340. if plugin_args.startswith('```'):
  341. plugin_args = re.sub(r'(^```(json)?\n|\n```$)', '', plugin_args.strip())
  342. print("plugin_args", plugin_args)
  343. try:
  344. plugin_args = json.loads(plugin_args)
  345. if isinstance(plugin_args, list):
  346. plugin_args = plugin_args[0]
  347. org_query = ''
  348. if 'organization_name' in plugin_args:
  349. org_query = plugin_args['organization_name']
  350. elif "name" in plugin_args and plugin_args['name'] == "organization_name":
  351. org_query = plugin_args['value']
  352. body, _ = gen_company_search_dsl(org_query)
  353. # print(f'es to execute: {body}')
  354. result = await es.search(body=body, index='company_basic_info_index_new')
  355. # print(f'es result: {result}')
  356. names = []
  357. if result["hits"]['hits'][0]['_source']['companyname'] == org_query:
  358. return json.dumps({"ExactMatch": org_query}, ensure_ascii=False), plugin_args, True
  359. for h in result["hits"]['hits']:
  360. names.append(h['_source']['companyname'])
  361. result = json.dumps({"FuzzyMatch": names}, ensure_ascii=False)
  362. return f"```json\n{result}\n```", plugin_args, True
  363. except:
  364. return f"ERROR:{traceback.format_exc()}", plugin_args, False
  365. elif plugin_name == 'generate_table':
  366. try:
  367. # plot_args = json.loads(plugin_args.replace("\'","\""))
  368. # plot_args['series']['data'] = sorted(plot_args['series']['data'],key=lambda x: x[0])
  369. return plugin_args.replace("\'", "\""), plugin_args, True
  370. except:
  371. return f"ERROR:{traceback.format_exc()}", plugin_args, False
  372. elif plugin_name == 'generate_chart':
  373. try:
  374. print(f'code args: {plugin_args}')
  375. code_result = await code_interpreter(plugin_args)
  376. triple_match = re.search(r'```[^\n]*\n(.+?)```', code_result, re.DOTALL)
  377. # print(f'code results: {plugin_args}')
  378. if triple_match:
  379. text = triple_match.group(1)
  380. if code_result.startswith("error:"):
  381. return f"\n{text}", plugin_args, False
  382. else:
  383. print('code intere succ!', simplify_chart(text))
  384. return f"\n{simplify_chart(text)}\n", plugin_args, True
  385. return code_result, plugin_args, False
  386. except:
  387. return f"ERROR:{traceback.format_exc()}", plugin_args, False
  388. elif plugin_name == 'recall_words':
  389. try:
  390. client = httpclient.InferenceServerClient(url="10.10.0.11:20100", verbose=False, network_timeout=3000)
  391. input_tensors = [httpclient.InferInput("source", [1, 1],
  392. "BYTES")] # httpclient.InferInput("product_list", [1,1], "BYTES"),
  393. source = np.array([['bid_prods'.encode('utf-8')]], dtype=np.object_)
  394. input_tensors[0].set_data_from_numpy(source)
  395. words = json.loads(plugin_args)['mention']
  396. if isinstance(words, str):
  397. text = np.array([[words.encode('utf-8')]], dtype=np.object_)
  398. input_tensors.append(httpclient.InferInput("product_list", [1, 1], "BYTES"))
  399. input_tensors[1].set_data_from_numpy(text)
  400. words = [words]
  401. elif isinstance(words, list):
  402. text = np.array([[t.encode('utf-8')] for t in words], dtype=np.object_)
  403. input_tensors.append(httpclient.InferInput("product_list", [len(words), 1], "BYTES"))
  404. input_tensors[1].set_data_from_numpy(text)
  405. else:
  406. return 'plugin_args ERROR', plugin_args, False
  407. outputs = [
  408. httpclient.InferRequestedOutput("output_products"),
  409. ]
  410. results = client.infer(model_name="product_sim_recall", inputs=input_tensors, outputs=outputs)
  411. output_data = results.as_numpy("output_products")
  412. return '|'.join(list(set([data.decode('utf-8') for data in output_data] + words))), plugin_args, True
  413. except:
  414. return f"ERROR:{traceback.format_exc()}", plugin_args, False
  415. else:
  416. raise NotImplementedError
  417. tools_list = [
  418. {
  419. 'name_for_human': '查询招投标数据库',
  420. 'name_for_model': 'TenderResultSqlAgent',
  421. 'description_for_model': """
  422. 当需要连接MySQL数据库并执行一段sql时,请使用此功能。
  423. """
  424. + ' Format the arguments as a JSON object.',
  425. 'parameters': [{'name': 'sql_code', 'type': 'string',
  426. 'description': '合法的MySQL查询语言。不接受【select *】,必须使用【select xxx,yyy】'}]
  427. },
  428. {
  429. 'name_for_human': '获取准确的公司或单位名称',
  430. 'name_for_model': 'search_ES',
  431. 'description_for_model': '当需要获取公司的具体名称时,请使用此API。输入需要确认的公司或单位名称,返回与之相似的公司或单位名称'
  432. + ' Format the arguments as a JSON object.',
  433. # 'parameters': [{'name': 'dsl_json', 'type': 'string', 'description': 'Json格式的Elasticsearch的DSL查询语言,可以match的字段有:【companyname】'}]
  434. 'parameters': [{'name': 'organization_name', 'type': 'string', 'description': '需要确认的公司或单位名称'}]
  435. },
  436. {
  437. 'name_for_human': '生成表格',
  438. 'name_for_model': 'generate_table',
  439. 'description_for_model': '[生成表格]用于将数据以表格方式呈现。输入HTML格式的表格数据,返回生成的表格url。'
  440. + ' Format the arguments as a JSON object.',
  441. 'parameters': [
  442. {
  443. 'name': 'HTML_table',
  444. 'description': """HTML格式的表格数据,例如:<table><tr><th>xxx</th><th>nnnn</th></tr><tr><td>ssss</td><td>dddd</td></tr></table>""",
  445. 'required': True,
  446. 'schema': {'type': 'HTML'},
  447. }
  448. ]
  449. },
  450. {
  451. 'name_for_human': '生成图表',
  452. 'name_for_model': 'generate_chart',
  453. 'description_for_model': """[生成折线图/柱状图/散点图/饼图]是一个图像生成服务,用于生成折线图/柱状图/散点图。输入需求和数据,返回生成的图表的python代码。""",
  454. 'parameters': [{
  455. 'name': 'code',
  456. 'type': 'string',
  457. 'description': '待执行的pyecharts代码。'
  458. }]
  459. },
  460. {
  461. 'name_for_human': '语义召回',
  462. 'name_for_model': 'recall_words',
  463. 'description_for_model': """[语义召回]可以返回与输入文本相似的文本""",
  464. 'parameters': [{
  465. 'name': 'mention',
  466. 'type': 'list of string',
  467. 'description': "原始文本中需要召回的文本,通常是一个或多个词,构成一个list:['xxx','yyy']"
  468. }]
  469. }
  470. ]
  471. if __name__ == "__main__":
  472. code = """import json
  473. import pyecharts.options as opts
  474. from pyecharts.globals import ThemeType
  475. from pyecharts.charts import Line
  476. #数据:格式为Dict
  477. source = {
  478. '发布月份': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'],
  479. '招标次数': ['16', '19', '32', '20', '26', '14', '29', '22', '38', '21', '14', '6']
  480. }
  481. #折线图
  482. line = Line()
  483. line.add_xaxis(source['发布月份'])
  484. line.add_yaxis('招标次数', source['招标次数'])
  485. line.set_global_opts(title_opts=opts.TitleOpts(title='上海市青浦区徐泾镇人民政府招标公告发布频率'),
  486. xaxis_opts=opts.AxisOpts(name='月份'),
  487. yaxis_opts=opts.AxisOpts(name='招标次数'))
  488. option = line.dump_options()
  489. print(option)
  490. """
  491. print('-----\n', code_interpreter(code))
  492. # s = MySQLSearcher()
  493. # print(s.run("SELECT publishdate, COUNT(*) as bid_count FROM tmp_qwen_copilot_orgname_bidding_success WHERE bidding_org_name = '34部队' GROUP BY publishdate ORDER BY publishdate"))