llm.py 853 B

12345678910111213141516171819202122232425
  1. import torch
  2. from models.base import HFModel
  class LLM(HFModel):
      """Text-generation wrapper built on top of ``HFModel``.

      NOTE(review): ``self.tokenizer`` and ``self.model`` are not set in this
      class; they are presumably created by ``HFModel.__init__`` from
      ``model_path`` -- confirm against ``models.base``.
      """

      def __init__(self, model_path):
          """Initialise the wrapper by delegating to ``HFModel.__init__``.

          Args:
              model_path: Forwarded unchanged to the base-class constructor.
          """
          super().__init__(model_path)

      @staticmethod
      def _truncate_at_stop_words(text, stop_words):
          """Cut ``text`` right after the first occurrence of each stop word.

          Stop words are applied in the order given, each one against the
          already-truncated text. The matched stop word itself is kept at the
          end of the result. Empty stop words are skipped, because
          ``str.find("")`` returns 0 and would otherwise erase the whole text.
          """
          for stop_str in stop_words:
              if not stop_str:
                  continue
              idx = text.find(stop_str)
              if idx != -1:
                  text = text[: idx + len(stop_str)]
          return text

      def generate(self, input_text, stop_words=None, max_new_tokens=512):
          """Generate a continuation for ``input_text`` with sampling disabled.

          Args:
              input_text: A prompt string, or a list of prompt strings. A single
                  string is wrapped in a one-element list. Only the first
                  sequence returned by the model is decoded, and the token ids
                  are turned into a tensor without padding, so a list of
                  prompts only works if they all tokenize to the same length.
              stop_words: Optional iterable of strings. The decoded output is
                  truncated just after the first occurrence of each one (the
                  stop word is kept). A single string is treated as one stop
                  word. ``None`` (the default) or an empty iterable disables
                  truncation.
              max_new_tokens: Passed to ``self.model.generate`` as
                  ``max_new_tokens``.

          Returns:
              The decoded text of the newly generated tokens for the first
              prompt, with special tokens skipped and stop-word truncation
              applied.
          """
          if isinstance(input_text, str):
              input_text = [input_text]
          if stop_words is None:
              stop_words = ()
          elif isinstance(stop_words, str):
              # Iterating a bare string would treat every character as a stop
              # word; treat it as a single stop word instead.
              stop_words = (stop_words,)

          input_ids = self.tokenizer(input_text)['input_ids']
          input_ids = torch.tensor(input_ids, device=self.model.device)

          gen_kwargs = {'max_new_tokens': max_new_tokens, 'do_sample': False}
          outputs = self.model.generate(input_ids, **gen_kwargs)

          # Drop the first ``input_ids.shape[1]`` positions of the first
          # returned sequence -- assumes the model echoes the prompt tokens
          # before the continuation (TODO confirm for the models used here).
          new_tokens = outputs[0][input_ids.shape[1]:]
          output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

          return self._truncate_at_stop_words(output, stop_words)