qwen.py 810 B

12345678910111213141516171819202122
  1. import torch
  2. from models.base import HFModel
  3. class Qwen(HFModel):
  4. def __init__(self, model_path):
  5. super().__init__(model_path)
  6. def generate(self, input_text, stop_words=[]):
  7. im_end = '<|im_end|>'
  8. if im_end not in stop_words:
  9. stop_words = stop_words + [im_end]
  10. stop_words_ids = [self.tokenizer.encode(w) for w in stop_words]
  11. input_ids = torch.tensor([self.tokenizer.encode(input_text)]).to(self.model.device)
  12. output = self.model.generate(input_ids, stop_words_ids=stop_words_ids)
  13. output = output.tolist()[0]
  14. output = self.tokenizer.decode(output, errors='ignore')
  15. assert output.startswith(input_text)
  16. output = output[len(input_text):].replace('<|endoftext|>', '').replace(im_end, '')
  17. return output