# config.py
from parser import InternLMReActParser, ReActParser
from models import LLM, HFModel, Qwen
from prompt import InternLMReAct, LlamaReAct, QwenReAct
  4. react_prompt_map = {
  5. 'qwen': QwenReAct,
  6. 'lianqiai': LianqiaiReAct,
  7. 'llama': LlamaReAct,
  8. 'internlm': InternLMReAct,
  9. }
  10. react_parser_map = {
  11. 'qwen': ReActParser,
  12. 'lianqiai': ReActParser,
  13. 'llama': ReActParser,
  14. 'internlm': InternLMReActParser,
  15. }
  16. model_map = {
  17. 'qwen': Qwen,
  18. 'lianqiai': Qwen,
  19. 'llama': LLM,
  20. 'internlm': LLM,
  21. 'qwen-vl': HFModel
  22. }
  23. model_type_map = {
  24. 'qwen-14b-chat': 'qwen',
  25. 'qwen-1.8b-chat': 'qwen',
  26. 'qwen-7b-chat': 'qwen',
  27. 'llama-2-7b-chat': 'llama',
  28. 'llama-2-13b-chat': 'llama',
  29. 'codellama-7b-instruct': 'llama',
  30. 'codellama-13b-instruct': 'llama',
  31. 'internlm-7b-chat-1.1': 'internlm',
  32. 'internlm-20b-chat': 'internlm',
  33. 'qwen-vl-chat': 'qwen-vl',
  34. }
  35. model_path_map = {
  36. 'qwen-14b-chat': 'Qwen/Qwen-14B-Chat',
  37. 'qwen-7b-chat': 'Qwen/Qwen-7B-Chat',
  38. 'qwen-1.8b-chat': 'Qwen/Qwen-1.8B-chat',
  39. 'llama-2-7b-chat': 'meta-llama/Llama-2-7b-chat-hf',
  40. 'llama-2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',
  41. 'codellama-7b-instruct': 'codellama/CodeLlama-7b-Instruct-hf',
  42. 'codellama-13b-instruct': 'codellama/CodeLlama-13b-Instruct-hf',
  43. 'internlm-7b-chat-1.1': 'internlm/internlm-chat-7b-v1_1',
  44. 'internlm-20b-chat': 'internlm/internlm-chat-20b',
  45. 'qwen-vl-chat': 'Qwen/Qwen-VL-Chat',
  46. }
  47. def get_react_prompt(model_name, query, lang, upload_fname_list):
  48. react_prompt_cls = react_prompt_map.get(model_type_map[model_name], QwenReAct)
  49. return react_prompt_cls(query, lang, upload_fname_list)
  50. def get_react_parser(model_name):
  51. react_parser_cls = react_parser_map.get(model_type_map[model_name], ReActParser)
  52. return react_parser_cls()
  53. def get_model(model_name):
  54. model_path = model_path_map[model_name]
  55. model_cls = model_map.get(model_type_map[model_name], LLM)
  56. return model_cls(model_path)