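"""Database/search tool backends and tool definitions for the bidding (招投标) agent.

Provides synchronous and asynchronous searchers for MySQL (pymysql/aiomysql) and
PostgreSQL (asyncpg), an Elasticsearch-based company-name lookup, helpers for chart
option simplification and amount formatting, the `call_plugin` dispatcher, and the
`tools_list` tool schema exposed to the model.
"""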
import asyncpg
from qwen_agent.config import db_config
from qwen_agent.tools.image_gen import image_gen  # NOQA
import json
from langchain.utilities import SQLDatabase
import pymysql
from elasticsearch import AsyncElasticsearch, Elasticsearch
import traceback
from qwen_agent.tools.query_understanding import gen_company_search_dsl
from datetime import datetime
from qwen_agent.tools.code_interpreter import code_interpreter
import re
import asyncio
import aiomysql
import tritonclient.http as httpclient
import numpy as np
from qwen_agent.utils.util import toJson

pymysql.install_as_MySQLdb()


def simplify_chart(option):
    result = {'series': [], 'xAxis': [], 'yAxis': [], 'title': [], 'legend': []}
    json_data = json.loads(option)
    for key_level1 in ['series', 'xAxis', 'yAxis', 'title', 'legend']:
        for item in json_data.get(key_level1, []):
            tmp = {}
            for key, value in item.items():
                if key in ['type', 'data', 'text', 'name', 'label']:
                    tmp[key] = value
            result[key_level1].append(tmp)
    # When there are multiple series (line or bar charts), keep only the first one selected in the legend
    if len(result['series']) > 1 and result['series'][0]['type'] in ('line', 'bar'):
        selected = {}
        for i, series in enumerate(result['series']):
            selected[series['name']] = (i == 0)
        result['legend'] = dict(selected=selected)
    return json.dumps(result, ensure_ascii=False)
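
# Rough illustration of what simplify_chart keeps (values below are made up for the example):
# an ECharts option with two line series such as
#   {"series": [{"type": "line", "name": "A", "data": [1, 2], "symbol": "circle"},
#               {"type": "line", "name": "B", "data": [3, 4]}], "title": [{"text": "demo"}]}
# is reduced to its type/data/name/text/label fields only, and the legend becomes
#   {"selected": {"A": true, "B": false}}, so only the first series is shown initially.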


def format_json(data):
    # Walk every key/value pair in the dict
    for key in data:
        if '金额' in key or '预算' in key:
            # Check whether every element can be converted to a float
            try:
                values = [float(value) if value != 'None' else 0 for value in data[key]]
                # If so, rescale the whole column by its maximum to a readable unit
                max_value = max(values)
                if max_value >= 1e8:
                    data[key] = [f"{value / 1e8:.2f}亿" for value in values]
                elif max_value >= 1e4:
                    data[key] = [f"{value / 1e4:.2f}万" for value in values]
                else:
                    data[key] = [f"{value}" for value in values]
            except ValueError:
                # Otherwise keep the original values
                pass
    return data
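
# Example (illustrative numbers): amount-like columns are rescaled as a whole, e.g.
#   format_json({'中标金额': ['150000000', '230000000']})
# returns {'中标金额': ['1.50亿', '2.30亿']}; columns whose key contains neither 金额 nor 预算
# are left untouched.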


class MySQLSearcher():
    def __init__(self, db_name='lianqiai_db') -> None:
        # Open the database connection
        self.db_name = db_name
        self.connect()
        # print(f'db name: {self.db_name}')

    def connect(self):
        self.db = pymysql.connect(
            host='xx.aliyuncs.com',
            user='xxx',
            password='xxx',
            db=self.db_name,
            port=3306
        )
        self.cursor = self.db.cursor()

    def format_result(self):
        result = self.cursor.fetchmany(20)
        if len(result) == 0:
            return ''
        headers = [column[0] for column in self.cursor.description]
        # print(f"headers:{headers}")
        # Build a column-oriented JSON dict instead of a Markdown table
        # markdown_table = tabulate(result, headers, tablefmt="pipe")
        json_data = {}
        for i, header in enumerate(headers):
            json_data[header] = [f"{row[i]}" if len(f"{row[i]}") > 0 else '未知' for row in result]
        json_data = format_json(json_data)
        return json.dumps(json_data, ensure_ascii=False)

    def run(self, command):
        try:
            self.cursor.execute(command)
            result = self.format_result()
        except Exception:
            # Reconnect once and retry before giving up
            try:
                self.connect()
                self.cursor.execute(command)
                result = self.format_result()
            except Exception:
                return f"ERROR:{traceback.format_exc()[-200:]}", False
        return result, True

    def _close(self):
        self.cursor.close()
        self.db.commit()
        self.db.close()
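
# Minimal usage sketch (host/user/password above are placeholders and must be filled in;
# the table name below is hypothetical):
#   searcher = MySQLSearcher(db_name='lianqiai_db')
#   text, ok = searcher.run("SELECT 中标单位, 中标金额 FROM some_table LIMIT 5")
#   searcher._close()
# `run` returns a (json_string, success_flag) tuple; at most 20 rows are formatted.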


class AsyncMySQLSearcher():
    def __init__(self, db_name='lianqiai_db') -> None:
        # The connection/pool is created lazily via connect()/register()
        self.db_name = db_name
        self.db = None
        self.pool = None
        # await self.connect()
        # print(f'db name: {self.db_name}')

    async def connect(self):
        self.db = await aiomysql.connect(
            # host='rm-bp13i5ci7o9ev1241ho.mysql.rds.aliyuncs.com',
            host='xxxx.ads.aliyuncs.com',  # analysisDB
            user='xxxx',
            password='xxxxx',
            db=self.db_name,
            port=3306
        )
        # self.db = await aiomysql.connect(
        #     host='10.10.0.10',
        #     user='root',
        #     password='Lianqiai',
        #     db='lianqi_db',
        #     port=13306
        # )
        self.cursor = await self.db.cursor()

    async def register(self):
        '''
        Initialization: create the database connection pool.
        :return:
        '''
        try:
            print("start to connect db!")
            self.pool = await aiomysql.create_pool(host='amv-bp1sk343446b8u0d100001808o.ads.aliyuncs.com', port=3306,
                                                   user='lianqi_admin', password='(lianqi666666)',
                                                   db='lianqiai_db')
            print("succeed to connect db!")
        except asyncio.CancelledError:
            raise asyncio.CancelledError
        except Exception as ex:
            print("MySQL connection failed: {}".format(ex.args[0]))

    async def run(self, sql):
        '''
        Query: acquire a connection and cursor from the pool, fetch the data,
        then release the connection back to the pool.
        :param sql: SQL statement to execute
        :return: (formatted result, success flag)
        '''
        if not self.pool:
            await self.register()
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(sql)
                    result = await cur.fetchmany(30)
                    return await self.format_result(result, cur), True
        except Exception:
            return f"ERROR:{traceback.format_exc()[-200:]}", False

    async def format_result(self, result, cursor):
        if len(result) == 0:
            return ''
        headers = [column[0] for column in cursor.description]
        json_data = {}
        for i, header in enumerate(headers):
            json_data[header] = [f"{row[i]}" if len(f"{row[i]}") > 0 else '未知' for row in result]
        json_data = format_json(json_data)
        # return json.dumps(json_data, ensure_ascii=False)
        # Convert the column-oriented dict into a list of row dicts
        row_result = []
        for i in range(len(json_data[headers[0]])):
            tmp_dict = {}
            for head in headers:
                tmp_dict[head] = json_data[head][i]
            row_result.append(tmp_dict)
        return json.dumps(row_result, ensure_ascii=False)

    async def _close(self):
        # Only meaningful when connect() was used directly; register()/run() go through the pool
        await self.cursor.close()
        await self.db.commit()
        self.db.close()
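
# Minimal usage sketch (run inside an event loop; the pool is created lazily on first use,
# using the credentials configured in register()):
#   async def demo():
#       db = AsyncMySQLSearcher()
#       rows_json, ok = await db.run("SELECT 1 AS probe")
#       return rows_json, ok
#   print(asyncio.run(demo()))
# The result is a JSON list of row dicts (at most 30 rows), or '' for an empty result set.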


class AsyncPGSearcher:
    def __init__(self, db_name='pg') -> None:
        # The connection/pool is created lazily; credentials come from db_config.db_list
        self.db_name = db_name
        self.db = None
        self.pool = None
        # await self.connect()
        # print(f'db name: {self.db_name}')

    async def connect(self):
        pg_config = db_config.db_list.get(self.db_name)
        print(f"pg_config: {pg_config}")
        self.db = await asyncpg.connect(
            host=pg_config.get("host"),  # analysisDB
            user=pg_config.get("user"),
            password=pg_config.get("password"),
            database=pg_config.get("database"),
            port=pg_config.get("port")
        )
        # asyncpg has no persistent cursor object; queries go through fetch()/execute()

    async def register(self):
        '''
        Initialization: create the database connection pool.
        :return:
        '''
        try:
            print("start to connect db!")
            pg_config = db_config.db_list.get(self.db_name)
            self.pool = await asyncpg.create_pool(host=pg_config.get("host"),
                                                  user=pg_config.get("user"),
                                                  password=pg_config.get("password"),
                                                  database=pg_config.get("database"),
                                                  port=pg_config.get("port"))
            print("succeed to connect db!")
        except asyncio.CancelledError:
            raise asyncio.CancelledError
        except Exception as ex:
            print("PostgreSQL connection failed: {}".format(ex.args[0]))

    async def run(self, sql):
        '''
        Query: fetch the rows through the connection pool and format them as JSON.
        :param sql: SQL statement to execute
        :return: (formatted result, success flag)
        '''
        if not self.pool:
            await self.register()
        try:
            result = await self.pool.fetch(sql)
            return await self.format_result(result), True
        except Exception:
            # return f"ERROR:{json.dumps(e, ensure_ascii=False)}", False
            return f"ERROR:{traceback.format_exc()[-200:]}", False

    async def format_result(self, result):
        if len(result) == 0:
            return ''
        result_json_obj = [dict(record) for record in result]
        return toJson(result_json_obj)

    async def _close(self):
        if self.db is not None:
            await self.db.close()
        if self.pool is not None:
            await self.pool.close()
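
# Minimal usage sketch, assuming db_config.db_list contains a 'pg' entry (as used below):
#   async def demo():
#       pg = AsyncPGSearcher('pg')
#       rows_json, ok = await pg.run("SELECT 1 AS probe")
#       await pg._close()
#       return rows_json, ok
#   print(asyncio.run(demo()))
# On success rows_json is a JSON list of row dicts; on failure it is an 'ERROR:...' string and ok is False.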


# db = MySQLSearcher()
# NOTE: `es` must be an (Async)Elasticsearch client for the `search_ES` plugin below to work;
# its construction is currently commented out.
# es = AsyncElasticsearch([{'host': 'xxxxx.public.elasticsearch.aliyuncs.com', 'port': 9200, 'scheme': "http"}],
#                         http_auth=('xxx', 'xxxx!'), timeout=3600)
# async_db = AsyncMySQLSearcher()
# asyncio.run(async_db.register())
async_db = AsyncPGSearcher('pg')
async_xzdb = AsyncPGSearcher('xzpg')


def modify_sql(sql_query):
    # Use a regex to find GROUP BY and everything after it, up to HAVING, ORDER BY, ';' or end of string
    match = re.search(r'(GROUP BY|group by)(.*?)(order by|having|HAVING|ORDER BY|;|$)', sql_query, re.IGNORECASE)
    target_table = "agent_bidding_history_detail_all"
    # If a GROUP BY clause was found
    if match:
        # Extract the fields after GROUP BY, strip extra whitespace and split them into a list
        fields = match.group(2).strip().split(',')
        for field in fields:
            if field.strip() == '招标单位':
                target_table = 'agent_bidding_history_detail_by_ifb_new'
                break
            elif field.strip() == '中标单位':
                target_table = 'agent_bidding_history_detail_by_wtb_new'
                break
            elif field.strip() == '招标产品':
                target_table = 'agent_bidding_history_detail_by_prod_new'
                break
    sql_query = sql_query.replace('agent_bidding_history_detail_all', target_table)
    return sql_query
    # return [field.strip() for field in fields]  # strip whitespace around each field
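
# Illustration of the table-routing rule above (the query is made up):
#   modify_sql("SELECT 招标单位, COUNT(*) FROM agent_bidding_history_detail_all GROUP BY 招标单位")
# rewrites the table to agent_bidding_history_detail_by_ifb_new, while a query grouped by 中标单位
# or 招标产品 is routed to the _by_wtb_new / _by_prod_new table; other queries are left unchanged.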


async def call_plugin(plugin_name: str, plugin_args: str):
    import time
    print("plugin_args", plugin_args)
    print("plugin_name", plugin_name)
    # Strip a wrapping markdown code fence from the model-generated arguments, if present
    if plugin_args.startswith('```'):
        triple_match = re.search(r'```[^\n]*\n(.+?)```', plugin_args, re.DOTALL)
        if triple_match:
            plugin_args = triple_match.group(1)
            print("plugin_args_clean:", plugin_args)
    if plugin_name == "TenderResultSqlAgent":
        try:
            sql_to_execute = json.loads(plugin_args)['sql_code']
            sql_to_execute = modify_sql(sql_to_execute)
            print(f"sql_to_execute:{sql_to_execute}")
            print("plugin_name", plugin_name)
            a = time.time()
            res_tuples = await async_db.run(sql_to_execute)
            print('SQL Time Cost:', time.time() - a)
            result, success = res_tuples
            if success and len(result) == 0:
                return '暂未查询到相关信息,可能是字段错误,请尝试调整查询条件,例如where条件中把“招标单位”改为“中标单位”', plugin_args, False
            return f"```json\n{result}\n```", plugin_args, success
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'search_ES':
        print(f"es plugin args: {plugin_args}")
        if plugin_args.startswith('```'):
            plugin_args = re.sub(r'(^```(json)?\n|\n```$)', '', plugin_args.strip())
            print("plugin_args", plugin_args)
        try:
            plugin_args = json.loads(plugin_args)
            if isinstance(plugin_args, list):
                plugin_args = plugin_args[0]
            org_query = ''
            if 'organization_name' in plugin_args:
                org_query = plugin_args['organization_name']
            elif "name" in plugin_args and plugin_args['name'] == "organization_name":
                org_query = plugin_args['value']
            body, _ = gen_company_search_dsl(org_query)
            # print(f'es to execute: {body}')
            result = await es.search(body=body, index='company_basic_info_index_new')
            # print(f'es result: {result}')
            names = []
            if result["hits"]['hits'][0]['_source']['companyname'] == org_query:
                return json.dumps({"ExactMatch": org_query}, ensure_ascii=False), plugin_args, True
            for h in result["hits"]['hits']:
                names.append(h['_source']['companyname'])
            result = json.dumps({"FuzzyMatch": names}, ensure_ascii=False)
            return f"```json\n{result}\n```", plugin_args, True
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'generate_table':
        try:
            # plot_args = json.loads(plugin_args.replace("\'", "\""))
            # plot_args['series']['data'] = sorted(plot_args['series']['data'], key=lambda x: x[0])
            return plugin_args.replace("\'", "\""), plugin_args, True
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'generate_chart':
        try:
            print(f'code args: {plugin_args}')
            code_result = await code_interpreter(plugin_args)
            triple_match = re.search(r'```[^\n]*\n(.+?)```', code_result, re.DOTALL)
            # print(f'code results: {plugin_args}')
            if triple_match:
                text = triple_match.group(1)
                if code_result.startswith("error:"):
                    return f"\n{text}", plugin_args, False
                else:
                    print('code interpreter succeeded!', simplify_chart(text))
                    return f"\n{simplify_chart(text)}\n", plugin_args, True
            return code_result, plugin_args, False
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    elif plugin_name == 'recall_words':
        try:
            client = httpclient.InferenceServerClient(url="10.10.0.11:20100", verbose=False, network_timeout=3000)
            input_tensors = [httpclient.InferInput("source", [1, 1], "BYTES")]
            source = np.array([['bid_prods'.encode('utf-8')]], dtype=np.object_)
            input_tensors[0].set_data_from_numpy(source)
            words = json.loads(plugin_args)['mention']
            if isinstance(words, str):
                text = np.array([[words.encode('utf-8')]], dtype=np.object_)
                input_tensors.append(httpclient.InferInput("product_list", [1, 1], "BYTES"))
                input_tensors[1].set_data_from_numpy(text)
                words = [words]
            elif isinstance(words, list):
                text = np.array([[t.encode('utf-8')] for t in words], dtype=np.object_)
                input_tensors.append(httpclient.InferInput("product_list", [len(words), 1], "BYTES"))
                input_tensors[1].set_data_from_numpy(text)
            else:
                return 'plugin_args ERROR', plugin_args, False
            outputs = [
                httpclient.InferRequestedOutput("output_products"),
            ]
            results = client.infer(model_name="product_sim_recall", inputs=input_tensors, outputs=outputs)
            output_data = results.as_numpy("output_products")
            return '|'.join(list(set([data.decode('utf-8') for data in output_data] + words))), plugin_args, True
        except:
            return f"ERROR:{traceback.format_exc()}", plugin_args, False
    else:
        raise NotImplementedError
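
# Dispatch sketch (argument values are illustrative): call_plugin is awaited with the tool name chosen
# by the model and its JSON arguments, and always returns (observation, cleaned_args, success):
#   obs, args, ok = await call_plugin(
#       "TenderResultSqlAgent",
#       '{"sql_code": "SELECT 中标单位, 中标金额 FROM agent_bidding_history_detail_all LIMIT 5"}',
#   )
# A wrapping ```json ... ``` fence around the arguments is stripped before dispatch.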


tools_list = [
    {
        'name_for_human': '查询招投标数据库',
        'name_for_model': 'TenderResultSqlAgent',
        'description_for_model': """
当需要连接MySQL数据库并执行一段sql时,请使用此功能。
"""
        + ' Format the arguments as a JSON object.',
        'parameters': [{'name': 'sql_code', 'type': 'string',
                        'description': '合法的MySQL查询语言。不接受【select *】,必须使用【select xxx,yyy】'}]
    },
    {
        'name_for_human': '获取准确的公司或单位名称',
        'name_for_model': 'search_ES',
        'description_for_model': '当需要获取公司的具体名称时,请使用此API。输入需要确认的公司或单位名称,返回与之相似的公司或单位名称'
        + ' Format the arguments as a JSON object.',
        # 'parameters': [{'name': 'dsl_json', 'type': 'string', 'description': 'Json格式的Elasticsearch的DSL查询语言,可以match的字段有:【companyname】'}]
        'parameters': [{'name': 'organization_name', 'type': 'string', 'description': '需要确认的公司或单位名称'}]
    },
    {
        'name_for_human': '生成表格',
        'name_for_model': 'generate_table',
        'description_for_model': '[生成表格]用于将数据以表格方式呈现。输入HTML格式的表格数据,返回生成的表格url。'
        + ' Format the arguments as a JSON object.',
        'parameters': [
            {
                'name': 'HTML_table',
                'description': """HTML格式的表格数据,例如:<table><tr><th>xxx</th><th>nnnn</th></tr><tr><td>ssss</td><td>dddd</td></tr></table>""",
                'required': True,
                'schema': {'type': 'HTML'},
            }
        ]
    },
    {
        'name_for_human': '生成图表',
        'name_for_model': 'generate_chart',
        'description_for_model': """[生成折线图/柱状图/散点图/饼图]是一个图像生成服务,用于生成折线图/柱状图/散点图。输入需求和数据,返回生成的图表的python代码。""",
        'parameters': [{
            'name': 'code',
            'type': 'string',
            'description': '待执行的pyecharts代码。'
        }]
    },
    {
        'name_for_human': '语义召回',
        'name_for_model': 'recall_words',
        'description_for_model': """[语义召回]可以返回与输入文本相似的文本""",
        'parameters': [{
            'name': 'mention',
            'type': 'list of string',
            'description': "原始文本中需要召回的文本,通常是一个或多个词,构成一个list:['xxx','yyy']"
        }]
    }
]
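
# For reference, arguments the model is expected to produce for these tools look like
# (values are illustrative):
#   TenderResultSqlAgent: {"sql_code": "SELECT 招标单位, COUNT(*) FROM agent_bidding_history_detail_all GROUP BY 招标单位"}
#   search_ES:            {"organization_name": "上海市青浦区徐泾镇人民政府"}
#   recall_words:         {"mention": ["xxx", "yyy"]}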
- if __name__ == "__main__":
- code = """import json
- import pyecharts.options as opts
- from pyecharts.globals import ThemeType
- from pyecharts.charts import Line
- #数据:格式为Dict
- source = {
- '发布月份': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'],
- '招标次数': ['16', '19', '32', '20', '26', '14', '29', '22', '38', '21', '14', '6']
- }
- #折线图
- line = Line()
- line.add_xaxis(source['发布月份'])
- line.add_yaxis('招标次数', source['招标次数'])
- line.set_global_opts(title_opts=opts.TitleOpts(title='上海市青浦区徐泾镇人民政府招标公告发布频率'),
- xaxis_opts=opts.AxisOpts(name='月份'),
- yaxis_opts=opts.AxisOpts(name='招标次数'))
- option = line.dump_options()
- print(option)
- """
- print('-----\n', code_interpreter(code))
- # s = MySQLSearcher()
- # print(s.run("SELECT publishdate, COUNT(*) as bid_count FROM tmp_qwen_copilot_orgname_bidding_success WHERE bidding_org_name = '34部队' GROUP BY publishdate ORDER BY publishdate"))