"""Tool ("plugin") implementations for the bidding-data agent.

``tools_list`` describes the tools offered to the model and ``call_plugin``
dispatches a tool call by its ``name_for_model``. SQL tool calls are executed
through the module-level ``async_db`` searcher.
"""
import asyncio
import json
import re
import time
import traceback
from datetime import datetime

import aiomysql
import asyncpg
import numpy as np
import pymysql
import tritonclient.http as httpclient
from elasticsearch import AsyncElasticsearch, Elasticsearch
from langchain.utilities import SQLDatabase

from qwen_agent.config import db_config
from qwen_agent.tools.code_interpreter import code_interpreter
from qwen_agent.tools.image_gen import image_gen  # NOQA
from qwen_agent.tools.query_understanding import gen_company_search_dsl
from qwen_agent.utils.util import toJson

pymysql.install_as_MySQLdb()

# Keys kept for every item of the chart sections handled by simplify_chart().
_CHART_ITEM_KEYS = ('type', 'data', 'text', 'name', 'label')
_CHART_SECTIONS = ('series', 'xAxis', 'yAxis', 'title', 'legend')


def simplify_chart(option):
    """Reduce an ECharts option JSON string to the keys the front end needs.

    For each of ``series``, ``xAxis``, ``yAxis``, ``title`` and ``legend`` only
    the ``type``/``data``/``text``/``name``/``label`` keys of every item are
    kept. When there are several series and the first one is a line or bar
    series, ``legend`` is replaced by ``{"selected": {...}}`` so that only the
    first series is shown initially.

    :param option: JSON string of an ECharts option (e.g. pyecharts ``dump_options()``).
    :return: the simplified option as a JSON string (non-ASCII characters kept).
    """
    json_data = json.loads(option)
    result = {}
    for section in _CHART_SECTIONS:
        items = json_data.get(section) or []
        if isinstance(items, dict):
            # ECharts accepts a single object as well as a list of objects.
            items = [items]
        result[section] = [
            {key: value for key, value in item.items() if key in _CHART_ITEM_KEYS}
            for item in items
        ]

    series = result['series']
    # Several line/bar series: show only the first one by default.
    if len(series) > 1 and series[0].get('type') in ('line', 'bar'):
        selected = {item['name']: index == 0 for index, item in enumerate(series)}
        result['legend'] = dict(selected=selected)
    return json.dumps(result, ensure_ascii=False)


def format_json(data):
    """Humanise monetary columns of a column-oriented result dict, in place.

    For every column whose name contains ``金额`` (amount) or ``预算`` (budget),
    the values are parsed as floats (the string ``'None'`` counts as 0). They
    are then rewritten with a unit chosen from the column maximum:

    * ``亿`` (1e8) with two decimals;
    * ``万`` (1e4) with two decimals;
    * otherwise the plain number string.

    A column that contains a non-numeric value, or is empty, is left unchanged.

    :param data: mapping of column name -> list of string cells; mutated in place.
    :return: the same mapping.
    """
    for key, column in data.items():
        if '金额' not in key and '预算' not in key:
            continue
        try:
            values = [float(value) if value != 'None' else 0 for value in column]
            max_value = max(values)
        except ValueError:
            # Non-numeric cell or empty column: keep the original strings.
            continue
        if max_value >= 1e8:
            data[key] = [f"{value / 1e8:.2f}亿" for value in values]
        elif max_value >= 1e4:
            data[key] = [f"{value / 1e4:.2f}万" for value in values]
        else:
            data[key] = [f"{value}" for value in values]
    return data


def _rows_to_columns(rows, description):
    """Pivot DB-API rows into ``(headers, {column: [str cells]})``.

    Cells whose string form is empty are replaced by ``'未知'`` (unknown).
    """
    headers = [column[0] for column in description]
    columns = {}
    for index, header in enumerate(headers):
        cells = [f"{row[index]}" for row in rows]
        columns[header] = [cell if cell else '未知' for cell in cells]
    return headers, columns


class MySQLSearcher():
    """Blocking MySQL query helper built on ``pymysql``."""

    def __init__(self, db_name='lianqiai_db') -> None:
        # Connects eagerly, so construction fails if the server is unreachable.
        self.db_name = db_name
        self.connect()

    def connect(self):
        """(Re)open the connection and create a fresh cursor."""
        self.db = pymysql.connect(
            host='xx.aliyuncs.com',
            user='xxx',
            password='xxx',
            db=self.db_name,
            port=3306
        )
        self.cursor = self.db.cursor()

    def format_result(self):
        """Return up to 20 rows of the last query as a column-oriented JSON string.

        :return: ``''`` when no rows were fetched.
        """
        rows = self.cursor.fetchmany(20)
        if len(rows) == 0:
            return ''
        _, columns = _rows_to_columns(rows, self.cursor.description)
        return json.dumps(format_json(columns), ensure_ascii=False)

    def run(self, command):
        """Execute *command*, reconnecting and retrying once on failure.

        :return: ``(json_string, True)`` on success, otherwise
            ``("ERROR:<last 200 chars of traceback>", False)``.
        """
        try:
            self.cursor.execute(command)
            return self.format_result(), True
        except Exception:
            pass
        # The connection may have been dropped by the server; reconnect and retry once.
        try:
            self.connect()
            self.cursor.execute(command)
            return self.format_result(), True
        except Exception:
            return f"ERROR:{traceback.format_exc()[-200:]}", False

    def _close(self):
        """Close the cursor, commit, and close the connection."""
        self.cursor.close()
        self.db.commit()
        self.db.close()


class AsyncMySQLSearcher():
    """Asynchronous MySQL query helper built on an ``aiomysql`` connection pool."""

    def __init__(self, db_name='lianqiai_db') -> None:
        # The pool is created lazily by register() on the first run().
        self.db_name = db_name
        self.db = None
        self.cursor = None
        self.pool = None

    async def connect(self):
        """Open a single (non-pooled) connection and cursor."""
        self.db = await aiomysql.connect(
            host='xxxx.ads.aliyuncs.com',  # analysisDB
            user='xxxx',
            password='xxxxx',
            db=self.db_name,
            port=3306
        )
        self.cursor = await self.db.cursor()

    async def register(self):
        """Create the connection pool.

        A failure is printed and leaves ``self.pool`` as ``None``. Cancellation
        is propagated.
        """
        try:
            print("start to connect db!")
            # SECURITY(review): credentials are hard-coded in source; move them to
            # db_config or environment variables and rotate this password.
            self.pool = await aiomysql.create_pool(host='amv-bp1sk343446b8u0d100001808o.ads.aliyuncs.com',
                                                   port=3306,
                                                   user='lianqi_admin',
                                                   password='(lianqi666666)',
                                                   db=self.db_name)
            print("succeed to connect db!")
        except asyncio.CancelledError:
            raise
        except Exception as ex:
            print("mysql数据库连接失败:{}".format(ex))

    async def run(self, sql):
        """Execute *sql* on a pooled connection and fetch up to 30 rows.

        The connection and cursor are released by the ``async with`` blocks.

        :return: ``(json_string, True)`` on success, where the JSON is a list of
            row objects (``''`` when there are no rows). On failure, returns
            ``("ERROR:...", False)``.
        """
        if not self.pool:
            await self.register()
        if not self.pool:
            return "ERROR:database connection pool is not available", False
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(sql)
                    result = await cur.fetchmany(30)
                    return await self.format_result(result, cur), True
        except Exception:
            return f"ERROR:{traceback.format_exc()[-200:]}", False

    async def format_result(self, result, cursor):
        """Turn fetched rows into a JSON list of ``{column: cell}`` objects.

        :return: ``''`` when *result* is empty.
        """
        if len(result) == 0:
            return ''
        headers, columns = _rows_to_columns(result, cursor.description)
        columns = format_json(columns)
        rows = [{head: columns[head][index] for head in headers} for index in range(len(result))]
        return json.dumps(rows, ensure_ascii=False)

    async def _close(self):
        """Close the single connection (committing first) and the pool, if opened."""
        if self.cursor is not None:
            await self.cursor.close()
            self.cursor = None
        if self.db is not None:
            await self.db.commit()
            self.db.close()
            self.db = None
        if self.pool is not None:
            self.pool.close()
            await self.pool.wait_closed()
            self.pool = None


class AsyncPGSearcher:
    """Asynchronous PostgreSQL query helper built on an ``asyncpg`` pool.

    Connection settings come from ``db_config.db_list[db_name]``, read via the
    keys ``host``, ``user``, ``password``, ``database`` and ``port``.
    """

    def __init__(self, db_name='pg') -> None:
        # The pool is created lazily by register() on the first run().
        self.db_name = db_name
        self.db = None
        self.cursor = None  # asyncpg has no standalone cursor; kept for attribute compatibility
        self.pool = None

    def _connection_kwargs(self):
        """Return the ``asyncpg`` connection keyword arguments for ``self.db_name``.

        :raises KeyError: if ``db_config.db_list`` has no entry for ``self.db_name``.
        """
        pg_config = db_config.db_list.get(self.db_name)
        if pg_config is None:
            raise KeyError(f"no database config named {self.db_name!r} in db_config.db_list")
        return dict(
            host=pg_config.get("host"),
            user=pg_config.get("user"),
            password=pg_config.get("password"),
            database=pg_config.get("database"),
            port=pg_config.get("port"),
        )

    async def connect(self):
        """Open a single (non-pooled) connection."""
        kwargs = self._connection_kwargs()
        # Log only non-secret settings: the config also carries the password.
        print(f"pg_config: host={kwargs['host']} database={kwargs['database']}")
        self.db = await asyncpg.connect(**kwargs)

    async def register(self):
        """Create the connection pool.

        A failure is printed and leaves ``self.pool`` as ``None``. Cancellation
        is propagated.
        """
        try:
            print("start to connect db!")
            self.pool = await asyncpg.create_pool(**self._connection_kwargs())
            print("succeed to connect db!")
        except asyncio.CancelledError:
            raise
        except Exception as ex:
            print("pg数据库连接失败:{}".format(ex))

    async def run(self, sql):
        """Execute *sql* on the pool and return every fetched row.

        :return: ``(json_string, True)`` on success (``''`` when no rows). On
            failure, returns ``("ERROR:<last 200 chars of traceback>", False)``.
        """
        if not self.pool:
            await self.register()
        if not self.pool:
            return "ERROR:database connection pool is not available", False
        try:
            result = await self.pool.fetch(sql)
            return await self.format_result(result), True
        except Exception:
            return f"ERROR:{traceback.format_exc()[-200:]}", False

    async def format_result(self, result):
        """Serialise fetched records as a JSON list of objects via ``toJson``.

        :return: ``''`` when *result* is empty.
        """
        if len(result) == 0:
            return ''
        return toJson([dict(record) for record in result])

    async def _close(self):
        """Close the single connection and the pool, if they were opened."""
        if self.db is not None:
            await self.db.close()
            self.db = None
        if self.pool is not None:
            await self.pool.close()
            self.pool = None


# Elasticsearch client used by the ``search_ES`` tool. Its original configuration
# was commented out; assign an AsyncElasticsearch instance here (or from outside
# this module) before that tool can work.
es = None

# Database searcher used by the ``TenderResultSqlAgent`` tool.
async_db = AsyncPGSearcher('pg')

# Detail table that modify_sql() may replace with a pre-aggregated table.
_DETAIL_TABLE = 'agent_bidding_history_detail_all'
# GROUP BY column -> pre-aggregated table to query instead of the detail table.
_GROUP_BY_TABLES = {
    '招标单位': 'agent_bidding_history_detail_by_ifb_new',
    '中标单位': 'agent_bidding_history_detail_by_wtb_new',
    '招标产品': 'agent_bidding_history_detail_by_prod_new',
}
# Captures the GROUP BY column list, up to ORDER BY / HAVING / ';' / end of query.
# DOTALL lets multi-line queries match.
_GROUP_BY_RE = re.compile(r'group\s+by(.*?)(?:order\s+by|having|;|$)', re.IGNORECASE | re.DOTALL)


def modify_sql(sql_query):
    """Route a query on the detail table to a pre-aggregated table when possible.

    The first GROUP BY column (quotes/backticks stripped) that is a key of
    ``_GROUP_BY_TABLES`` selects the replacement table. Every occurrence of
    ``agent_bidding_history_detail_all`` is then replaced by it. Without such
    a column, the query is returned unchanged.
    """
    target_table = _DETAIL_TABLE
    match = _GROUP_BY_RE.search(sql_query)
    if match:
        for field in match.group(1).split(','):
            column = field.strip().strip('`"')
            if column in _GROUP_BY_TABLES:
                target_table = _GROUP_BY_TABLES[column]
                break
    return sql_query.replace(_DETAIL_TABLE, target_table)


# Matches a Markdown code fence (```lang\n...```) and captures its body.
_CODE_FENCE_RE = re.compile(r'```[^\n]*\n(.+?)```', re.DOTALL)

# Triton endpoint serving the ``product_sim_recall`` model used by ``recall_words``.
_RECALL_SERVER_URL = "10.10.0.11:20100"


async def _call_sql_agent(plugin_args):
    """TenderResultSqlAgent: run the model's ``sql_code`` on ``async_db``."""
    try:
        sql_to_execute = modify_sql(json.loads(plugin_args)['sql_code'])
        print(f"sql_to_execute:{sql_to_execute}")
        # NOTE(review): model-generated SQL is executed verbatim; make sure the
        # configured database account is read-only.
        start = time.time()
        result, success = await async_db.run(sql_to_execute)
        print('SQL Time Cost:', time.time() - start)
        if success and len(result) == 0:
            return '暂未查询到相关信息,可能是字段错误,请尝试调整查询条件,例如where条件中把“招标单位”改为“中标单位”', plugin_args, False
        return f"```json\n{result}\n```", plugin_args, success
    except Exception:
        return f"ERROR:{traceback.format_exc()}", plugin_args, False


async def _call_search_es(plugin_args):
    """search_ES: look up an organisation name in the company index.

    :return: an ``ExactMatch`` JSON when the top hit equals the query, otherwise
        a fenced ``FuzzyMatch`` JSON listing all hit names. The returned
        ``plugin_args`` is the parsed JSON object once parsing succeeds.
    """
    print(f"es plugin args: {plugin_args}")
    if plugin_args.startswith('```'):
        plugin_args = re.sub(r'(^```(json)?\n|\n```$)', '', plugin_args.strip())
        print("plugin_args", plugin_args)
    try:
        plugin_args = json.loads(plugin_args)
        if isinstance(plugin_args, list):
            plugin_args = plugin_args[0]
        org_query = ''
        if 'organization_name' in plugin_args:
            org_query = plugin_args['organization_name']
        elif "name" in plugin_args and plugin_args['name'] == "organization_name":
            org_query = plugin_args['value']
        if es is None:
            return "ERROR:Elasticsearch client `es` is not configured", plugin_args, False
        body, _ = gen_company_search_dsl(org_query)
        result = await es.search(body=body, index='company_basic_info_index_new')
        hits = result["hits"]['hits']
        if hits and hits[0]['_source']['companyname'] == org_query:
            return json.dumps({"ExactMatch": org_query}, ensure_ascii=False), plugin_args, True
        names = [hit['_source']['companyname'] for hit in hits]
        result = json.dumps({"FuzzyMatch": names}, ensure_ascii=False)
        return f"```json\n{result}\n```", plugin_args, True
    except Exception:
        return f"ERROR:{traceback.format_exc()}", plugin_args, False


async def _call_generate_table(plugin_args):
    """generate_table: return the table text with single quotes turned into double quotes."""
    return plugin_args.replace("\'", "\""), plugin_args, True


async def _call_generate_chart(plugin_args):
    """generate_chart: run pyecharts code and return the simplified chart option."""
    try:
        print(f'code args: {plugin_args}')
        code_result = await code_interpreter(plugin_args)
        fenced = _CODE_FENCE_RE.search(code_result)
        if fenced:
            text = fenced.group(1)
            if code_result.startswith("error:"):
                return f"\n{text}", plugin_args, False
            chart = simplify_chart(text)
            print('code intere succ!', chart)
            return f"\n{chart}\n", plugin_args, True
        return code_result, plugin_args, False
    except Exception:
        return f"ERROR:{traceback.format_exc()}", plugin_args, False


async def _call_recall_words(plugin_args):
    """recall_words: query the Triton recall model for texts similar to ``mention``.

    :return: ``'|'``-joined, de-duplicated recalled texts plus the inputs.
    """
    client = None
    try:
        client = httpclient.InferenceServerClient(url=_RECALL_SERVER_URL, verbose=False, network_timeout=3000)
        source_input = httpclient.InferInput("source", [1, 1], "BYTES")
        source_input.set_data_from_numpy(np.array([['bid_prods'.encode('utf-8')]], dtype=np.object_))
        input_tensors = [source_input]

        words = json.loads(plugin_args)['mention']
        if isinstance(words, str):
            words = [words]
        elif not isinstance(words, list):
            return 'plugin_args ERROR', plugin_args, False
        product_input = httpclient.InferInput("product_list", [len(words), 1], "BYTES")
        product_input.set_data_from_numpy(np.array([[word.encode('utf-8')] for word in words], dtype=np.object_))
        input_tensors.append(product_input)

        outputs = [
            httpclient.InferRequestedOutput("output_products"),
        ]
        # NOTE(review): client.infer is synchronous and blocks the event loop while it runs.
        results = client.infer(model_name="product_sim_recall", inputs=input_tensors, outputs=outputs)
        output_data = results.as_numpy("output_products")
        recalled = [data.decode('utf-8') for data in output_data]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return '|'.join(dict.fromkeys(recalled + words)), plugin_args, True
    except Exception:
        return f"ERROR:{traceback.format_exc()}", plugin_args, False
    finally:
        if client is not None:
            try:
                client.close()
            except Exception:
                # Closing is best-effort; never mask the tool result.
                pass


# name_for_model -> coroutine handling that tool call.
_PLUGIN_HANDLERS = {
    'TenderResultSqlAgent': _call_sql_agent,
    'search_ES': _call_search_es,
    'generate_table': _call_generate_table,
    'generate_chart': _call_generate_chart,
    'recall_words': _call_recall_words,
}


async def call_plugin(plugin_name: str, plugin_args: str):
    """Dispatch a model tool call to its implementation.

    :param plugin_name: a ``name_for_model`` from ``tools_list``.
    :param plugin_args: the raw argument text produced by the model. A leading
        Markdown code fence is stripped before use.
    :return: ``(observation, plugin_args, success)``.
    :raises NotImplementedError: for an unknown *plugin_name*.
    """
    print("plugin_args", plugin_args)
    print("plugin_name", plugin_name)
    if plugin_args.startswith('```'):
        fenced = _CODE_FENCE_RE.search(plugin_args)
        if fenced:
            plugin_args = fenced.group(1)
            print("plugin_args_clean:", plugin_args)

    handler = _PLUGIN_HANDLERS.get(plugin_name)
    if handler is None:
        raise NotImplementedError(plugin_name)
    return await handler(plugin_args)


# Tool descriptions offered to the model.
# NOTE(review): TenderResultSqlAgent still describes MySQL, while async_db is a PostgreSQL searcher.
tools_list = [
    {
        'name_for_human': '查询招投标数据库',
        'name_for_model': 'TenderResultSqlAgent',
        'description_for_model': """
当需要连接MySQL数据库并执行一段sql时,请使用此功能。
""" + ' Format the arguments as a JSON object.',
        'parameters': [{'name': 'sql_code', 'type': 'string', 'description': '合法的MySQL查询语言。不接受【select *】,必须使用【select xxx,yyy】'}]
    },
    {
        'name_for_human': '获取准确的公司或单位名称',
        'name_for_model': 'search_ES',
        'description_for_model': '当需要获取公司的具体名称时,请使用此API。输入需要确认的公司或单位名称,返回与之相似的公司或单位名称' + ' Format the arguments as a JSON object.',
        'parameters': [{'name': 'organization_name', 'type': 'string', 'description': '需要确认的公司或单位名称'}]
    },
    {
        'name_for_human': '生成表格',
        'name_for_model': 'generate_table',
        'description_for_model': '[生成表格]用于将数据以表格方式呈现。输入HTML格式的表格数据,返回生成的表格url。' + ' Format the arguments as a JSON object.',
        'parameters': [
            {
                'name': 'HTML_table',
                # NOTE(review): the HTML tags of this example appear to have been
                # stripped (only the cell text remains); restore a real table example.
                'description': """HTML格式的表格数据,例如:
xxxnnnn
ssssdddd
""",
                'required': True,
                'schema': {'type': 'HTML'},
            }
        ]
    },
    {
        'name_for_human': '生成图表',
        'name_for_model': 'generate_chart',
        'description_for_model': """[生成折线图/柱状图/散点图/饼图]是一个图像生成服务,用于生成折线图/柱状图/散点图。输入需求和数据,返回生成的图表的python代码。""",
        'parameters': [{
            'name': 'code',
            'type': 'string',
            'description': '待执行的pyecharts代码。'
        }]
    },
    {
        'name_for_human': '语义召回',
        'name_for_model': 'recall_words',
        'description_for_model': """[语义召回]可以返回与输入文本相似的文本""",
        'parameters': [{
            'name': 'mention',
            'type': 'list of string',
            'description': "原始文本中需要召回的文本,通常是一个或多个词,构成一个list:['xxx','yyy']"
        }]
    }
]

if __name__ == "__main__":
    code = """import json
import pyecharts.options as opts
from pyecharts.globals import ThemeType
from pyecharts.charts import Line

#数据:格式为Dict
source = {
    '发布月份': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'],
    '招标次数': ['16', '19', '32', '20', '26', '14', '29', '22', '38', '21', '14', '6']
}

#折线图
line = Line()
line.add_xaxis(source['发布月份'])
line.add_yaxis('招标次数', source['招标次数'])
line.set_global_opts(title_opts=opts.TitleOpts(title='上海市青浦区徐泾镇人民政府招标公告发布频率'),
                     xaxis_opts=opts.AxisOpts(name='月份'),
                     yaxis_opts=opts.AxisOpts(name='招标次数'))
option = line.dump_options()
print(option)
"""
    # code_interpreter is a coroutine function (it is awaited in _call_generate_chart).
    print('-----\n', asyncio.run(code_interpreter(code)))