test_deepseek.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # from llama_cpp import Llama
  2. # # llm = Llama(model_path="/mnt/nas/model/nlp/DeepSeek_GGUF/deepseek-coder-33b-instruct.Q5_K_M.gguf")
  3. # # input="""You are an AI programming assistant, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
  4. # # ### Instruction:
  5. # # 写一个求编辑距离的python函数
  6. # # ### Response:
  7. # # """
  8. # # import time
  9. # # a = time.time()
  10. # # output = llm(input, max_tokens=512, echo=True,top_k=1)
  11. # # print(output,'\ntime:',time.time()-a)
  12. from transformers.generation import LogitsProcessor
  13. from typing import Tuple, List, Union, Iterable
  14. import numpy as np
  15. from transformers.generation.logits_process import LogitsProcessorList
  16. import torch
  17. class StopWordsLogitsProcessor(LogitsProcessor):
  18. """
  19. :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
  20. Args:
  21. stop_words_ids (:obj:`List[List[int]]`):
  22. List of list of token ids of stop ids. In order to get the tokens of the words
  23. that should not appear in the generated text, use :obj:`tokenizer(bad_word,
  24. add_prefix_space=True).input_ids`.
  25. eos_token_id (:obj:`int`):
  26. The id of the `end-of-sequence` token.
  27. """
  28. def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
  29. if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
  30. raise ValueError(
  31. f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
  32. )
  33. if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
  34. raise ValueError(
  35. f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
  36. )
  37. if any(
  38. any(
  39. (not isinstance(token_id, (int, np.integer)) or token_id < 0)
  40. for token_id in stop_word_ids
  41. )
  42. for stop_word_ids in stop_words_ids
  43. ):
  44. raise ValueError(
  45. f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
  46. )
  47. self.stop_words_ids = list(
  48. filter(
  49. lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
  50. )
  51. )
  52. self.eos_token_id = eos_token_id
  53. for stop_token_seq in self.stop_words_ids:
  54. assert (
  55. len(stop_token_seq) > 0
  56. ), "Stop words token sequences {} cannot have an empty list".format(
  57. stop_words_ids
  58. )
  59. def __call__(
  60. self, input_ids: torch.LongTensor, scores: torch.FloatTensor
  61. ) -> torch.FloatTensor:
  62. stopped_samples = self._calc_stopped_samples(input_ids)
  63. for i, should_stop in enumerate(stopped_samples):
  64. if should_stop:
  65. scores[i, self.eos_token_id] = float(2**15)
  66. return scores
  67. def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
  68. if len(tokens) == 0:
  69. # if bad word tokens is just one token always ban it
  70. return True
  71. elif len(tokens) > len(prev_tokens):
  72. # if bad word tokens are longer then prev input_ids they can't be equal
  73. return False
  74. elif prev_tokens[-len(tokens) :].tolist() == tokens:
  75. # if tokens match
  76. return True
  77. else:
  78. return False
  79. def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
  80. stopped_samples = []
  81. for prev_input_ids_slice in prev_input_ids:
  82. match = False
  83. for stop_token_seq in self.stop_words_ids:
  84. if self._tokens_match(prev_input_ids_slice, stop_token_seq):
  85. # if tokens do not match continue
  86. match = True
  87. break
  88. stopped_samples.append(match)
  89. return stopped_samples
  90. import torch
  91. from threading import Thread
  92. from transformers import AutoTokenizer, AutoModelForCausalLM,TextIteratorStreamer,TextStreamer
  93. tokenizer = AutoTokenizer.from_pretrained("/mnt/nas/model/nlp/Deepseek", trust_remote_code=True)
  94. model = AutoModelForCausalLM.from_pretrained(
  95. "/mnt/nas/model/nlp/Deepseek",
  96. device_map="cuda:2",
  97. trust_remote_code=True,torch_dtype=torch.float16
  98. ).eval()
  99. streamer = TextIteratorStreamer(tokenizer,skip_prompt=True,decode_kwargs={'skip_special_tokens':True,'errors':'ignore'})
  100. # def stream_generator(input_ids):
  101. # outputs = []
  102. # for token in NewGenerationMixin.generate(
  103. # input_ids,
  104. # return_dict_in_generate=False,
  105. # generation_config=stream_config,
  106. # seed=-1):
  107. # outputs.append(token.item())
  108. # yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
# Single-turn chat prompt (written in Chinese) for the text-to-SQL agent test.
# The user message asks the model to act as a MySQL expert, describes the
# `agent_bidding_history_detail` table, gives five example question/SQL pairs,
# lists one tool (TenderResultSqlAgent) and requires a ReAct-style answer
# (Question / Thought / Action / Action Input / Observation). The concrete
# question asks for recent tender demands for tires.
# NOTE: everything between the triple quotes is prompt text sent to the model,
# including the lines starting with `#` and the stray `'` that ends the first
# line -- none of it is Python syntax.
messages=[
    { 'role': 'user', 'content': """你是一个MySQL专家,当前需要根据用户问题和上下文,生成语法正确的MySQL查询语句。'
#数据库表的表名和表结构如下:
`agent_bidding_history_detail`(
`标题`,
`行业`,
`发布年份`,
`发布月份`,
`发布日`,
`发布日期`,
`省`,
`市`,
`区`,
`招标单位`,
`中标单位`,
`代理单位`,
`中标实际金额`,
`招标预算金额`,
`招标产品`,
`招标类型`,# 有两种:[中标结果]和[招标公告]。
)
有几个注意事项:
当涉及到时间时,尽量用CURDATE(),YEAR(),MONTH()等函数
请仔细区分"去年","今年","N年前"等时间关键词。
当涉及到地理位置时,请注意省市区的区分。
以下是可供参考的SQL写法(仅供参考,也可自由发挥):
```
0 请分析下这两年product的中标金额情况:SELECT 发布年份,SUM(中标实际金额) as 中标金额 FROM agent_bidding_history_detail WHERE 招标类型 = '中标结果' and (招标产品 like '%product%' or 标题 like '%product%') and 发布年份 BETWEEN YEAR(CURDATE())-2 AND YEAR(CURDATE()) GROUP BY 发布年份
1 company_A和company_B的一些合作记录发给我:SELECT 标题,发布日期,中标实际金额 FROM agent_bidding_history_detail WHERE 招标类型 = '中标结果' and ((招标单位='company_A' and 中标单位='company_B') or (招标单位='company_B' and 中标单位='company_A')) LIMIT 20
2 companyname【过去一年】的招标中,中标单位分布情况如何?:SELECT 中标单位,COUNT(1) as 中标个数, SUM(中标实际金额) as 中标金额 FROM agent_bidding_history_detail WHERE 招标类型 = '中标结果' and 招标单位 LIKE '%companyname%' and 发布年份 BETWEEN YEAR(CURDATE())-1 AND YEAR(CURDATE()) GROUP BY 中标单位
3 organization今年的招标情况如何?:SELECT COUNT(1) as 招标次数, SUM(招标预算金额) as 招标预算,GROUP_CONCAT( `招标产品`,',') as 招标产品 FROM agent_bidding_history_detail WHERE 招标单位 LIKE '%organization%' AND 招标类型 = '招标公告' AND 发布年份 = YEAR(CURDATE())
4 对比下organization去年和今年每个月的招标数量:SELECT 发布年份,发布月份,COUNT(1) as 招标次数 FROM agent_bidding_history_detail WHERE 招标单位 LIKE '%organization_name%' AND 招标类型 = '招标公告' GROUP BY 发布年份,发布月份 ORDER BY 发布年份,发布月份
```
下面是API列表,可以选择有助于完成用户需求的一个或多个API:
#API列表
TenderResultSqlAgent: Call this tool to interact with the 查询招投标数据库 API. What is the 查询招投标数据库 API useful for?
当需要连接MySQL数据库并执行一段sql时,请使用此功能。
Format the arguments as a JSON object. Parameters: [{"name": "sql_code", "type": "string", "description": "合法的MySQL查询语言。不接受【select *】,必须使用【select xxx,yyy】"}]
请依据以上可选择的API,制定计划完成用户需求,按照如下格式返回:
Question: 用户需求。
Thought: 生成计划的原因。
Action: 当前需要使用的API,必须包含在[TenderResultSqlAgent] 中。注意这里只需要放API的名字(name_for_model),不需要额外的信息
Action Input: 当前API的输入参数。注意这里只需要放JSON格式的API的参数(parameters),不需要额外的信息
Observation: API的输出。
... (以上 /Thought/Action/Action Input/Observation的过程可以重复多次,直到产生预期的效果)。
Begin!
Question: 用户的原始Question为:最近有什么轮胎的招标需求
对Question进行分析,我认为需要用以下执行计划完成用户的查询需求:[{"action_name": "TenderResultSqlAgent", "instruction": "查询最近的轮胎招标信息"}, {"action_name": "summary", "instruction": "对查询结果进行总结,回答用户的问题"}]
已经执行结束的Action如下:
需要执行的Action如下:
Instruction: 查询最近的轮胎招标信息"""}
]
  161. stop_words = []
  162. if "\nObservation:" not in stop_words:
  163. stop_words.append("\nObservation:")
  164. stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
  165. print('stop_words_ids:',stop_words_ids)
  166. if stop_words_ids is not None:
  167. stop_words_logits_processor = StopWordsLogitsProcessor(
  168. stop_words_ids=stop_words_ids,
  169. eos_token_id=32021,
  170. )
  171. logits_processor = LogitsProcessorList([stop_words_logits_processor])
  172. inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
  173. generation_kwargs = {}
  174. generation_kwargs['inputs']=inputs
  175. generation_kwargs['max_new_tokens']=512
  176. generation_kwargs['do_sample']=False
  177. generation_kwargs['num_return_sequences']=1
  178. # generation_kwargs['logits_processor']=logits_processor
  179. # 32021 is the id of <|EOT|> token
  180. stop_words_ids[0].append(32021)
  181. # generation_kwargs['eos_token_id']=stop_words_ids[0]
  182. generation_kwargs['streamer']=streamer
  183. thread = Thread(target=model.generate, kwargs=generation_kwargs)
  184. thread.start()
  185. generated_text = ''
  186. for new_text in streamer:
  187. if len(new_text)==0:
  188. continue
  189. if new_text != '<|EOT|>':
  190. generated_text+=new_text
  191. if 'Observation:' in generated_text:
  192. generated_text = generated_text.split('Observation:')[0]
  193. print(new_text,end='',flush=True)
  194. print('-----------')
  195. print(generated_text)
  196. # response = """
  197. # Action: TenderResultSqlAgent
  198. # Action Input: {"sql_code": "SELECT 标题, 发布日期, 中标实际金额 FROM agent_bidding_history_detail WHERE 招标类型 = '招标公告' AND 招标产品 LIKE '%轮胎%' ORDER BY 发布日期 DESC LIMIT 10"}
  199. # Observation:
  200. # ```
  201. # [
  202. # {"标题": "2022年轮胎招标", "发布日期": "2022-01-01", "中标实际金额": 1000000},
  203. # {"标题": "2021年轮胎招标", "发布日期": "2021-12-31", "中标实际金额": 900000},
  204. # {"标题": "2021年轮胎招标", "发布日期": "2021-12-30", "中标实际金额": 800000},
  205. # ...
  206. # ]
  207. # ```
  208. # Instruction: 对查询结果进行总结,回答用户的问题
  209. # Observation: 最近的轮胎招标信息包括:2022年1月1日的2022年轮胎招标,中标金额为1000000元;2021年12月31日的2021年轮胎招标,中标金额为900000元;2021年12月30日的2021年轮胎招标,中标金额为800000元。
  210. # 以上就是我根据用户需求生成的执行计划。
  211. # """
  212. # def parse_response_func(response):
  213. # func_name, func_args = "", ""
  214. # i = response.find("Action:")
  215. # j = response.find("\nAction Input:")
  216. # k = response.find("\nObservation:")
  217. # print(i,j,k)
  218. # if 0 <= i < j: # If the text has `Action` and `Action input`,
  219. # func_name = response[i + len("Action:") : j].strip()
  220. # func_args = response[j + len("\nAction Input:") : k].strip()
  221. # if func_name:
  222. # choice_data = {'role':"assistant","content":response[:i],
  223. # "function_call":{"name": func_name, "arguments": func_args}
  224. # }
  225. # return choice_data
  226. # print(parse_response_func(response=response))