123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- """Tokenization classes for QWen."""
- import base64
- import unicodedata
- from pathlib import Path
- from typing import Collection, Dict, List, Set, Union
- import tiktoken
- from qwen_agent.log import logger
- VOCAB_FILES_NAMES = {'vocab_file': 'qwen.tiktoken'}
- PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
- ENDOFTEXT = '<|endoftext|>'
- IMSTART = '<|im_start|>'
- IMEND = '<|im_end|>'
- # as the default behavior is changed to allow special tokens in
- # regular texts, the surface forms of special tokens need to be
- # as different as possible to minimize the impact
- EXTRAS = tuple((f'<|extra_{i}|>' for i in range(205)))
- # changed to use actual index to avoid misconfiguration with vocabulary expansion
- SPECIAL_START_ID = 151643
- SPECIAL_TOKENS = tuple(enumerate(
- ((
- ENDOFTEXT,
- IMSTART,
- IMEND,
- ) + EXTRAS),
- start=SPECIAL_START_ID,
- ))
- SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
- def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
- with open(tiktoken_bpe_file, 'rb') as f:
- contents = f.read()
- return {
- base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
- }
- class QWenTokenizer:
- """QWen tokenizer."""
- vocab_files_names = VOCAB_FILES_NAMES
- def __init__(
- self,
- vocab_file=None,
- errors='replace',
- extra_vocab_file=None,
- ):
- if not vocab_file:
- vocab_file = VOCAB_FILES_NAMES['vocab_file']
- self._decode_use_source_tokenizer = False
- # how to handle errors in decoding UTF-8 byte sequences
- # use ignore if you are in streaming inference
- self.errors = errors
- self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
- self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
- # try load extra vocab from file
- if extra_vocab_file is not None:
- used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
- extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
- for token, index in extra_mergeable_ranks.items():
- if token in self.mergeable_ranks:
- logger.info(f'extra token {token} exists, skipping')
- continue
- if index in used_ids:
- logger.info(f'the index {index} for extra token {token} exists, skipping')
- continue
- self.mergeable_ranks[token] = index
- # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
- enc = tiktoken.Encoding(
- 'Qwen',
- pat_str=PAT_STR,
- mergeable_ranks=self.mergeable_ranks,
- special_tokens=self.special_tokens,
- )
- assert len(self.mergeable_ranks) + len(
- self.special_tokens
- ) == enc.n_vocab, f'{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding'
- self.decoder = {v: k for k, v in self.mergeable_ranks.items()} # type: dict[int, bytes|str]
- self.decoder.update({v: k for k, v in self.special_tokens.items()})
- self.tokenizer = enc # type: tiktoken.Encoding
- self.eod_id = self.tokenizer.eot_token
- self.im_start_id = self.special_tokens[IMSTART]
- self.im_end_id = self.special_tokens[IMEND]
- def __getstate__(self):
- # for pickle lovers
- state = self.__dict__.copy()
- del state['tokenizer']
- return state
- def __setstate__(self, state):
- # tokenizer is not python native; don't pass it; rebuild it
- self.__dict__.update(state)
- enc = tiktoken.Encoding(
- 'Qwen',
- pat_str=PAT_STR,
- mergeable_ranks=self.mergeable_ranks,
- special_tokens=self.special_tokens,
- )
- self.tokenizer = enc
- def __len__(self) -> int:
- return self.tokenizer.n_vocab
- def get_vocab(self) -> Dict[bytes, int]:
- return self.mergeable_ranks
- def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]:
- ids = []
- if isinstance(tokens, (str, bytes)):
- if tokens in self.special_tokens:
- return self.special_tokens[tokens]
- else:
- return self.mergeable_ranks.get(tokens)
- for token in tokens:
- if token in self.special_tokens:
- ids.append(self.special_tokens[token])
- else:
- ids.append(self.mergeable_ranks.get(token))
- return ids
- def tokenize(
- self,
- text: str,
- allowed_special: Union[Set, str] = 'all',
- disallowed_special: Union[Collection, str] = (),
- ) -> List[Union[bytes, str]]:
- """
- Converts a string in a sequence of tokens.
- Args:
- text (`str`):
- The sequence to be encoded.
- allowed_special (`Literal["all"]` or `set`):
- The surface forms of the tokens to be encoded as special tokens in regular texts.
- Default to "all".
- disallowed_special (`Literal["all"]` or `Collection`):
- The surface forms of the tokens that should not be in regular texts and trigger errors.
- Default to an empty tuple.
- Returns:
- `List[bytes|str]`: The list of tokens.
- """
- tokens = []
- text = unicodedata.normalize('NFC', text)
- # this implementation takes a detour: text -> token id -> token surface forms
- for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
- tokens.append(self.decoder[t])
- return tokens
- def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
- """
- Converts a sequence of tokens in a single string.
- """
- text = ''
- temp = b''
- for t in tokens:
- if isinstance(t, str):
- if temp:
- text += temp.decode('utf-8', errors=self.errors)
- temp = b''
- text += t
- elif isinstance(t, bytes):
- temp += t
- else:
- raise TypeError('token should only be of type types or str')
- if temp:
- text += temp.decode('utf-8', errors=self.errors)
- return text
- @property
- def vocab_size(self):
- return self.tokenizer.n_vocab
- def _decode(
- self,
- token_ids: Union[int, List[int]],
- skip_special_tokens: bool = False,
- errors: str = None,
- ) -> str:
- if isinstance(token_ids, int):
- token_ids = [token_ids]
- if skip_special_tokens:
- token_ids = [i for i in token_ids if i < self.eod_id]
- return self.tokenizer.decode(token_ids, errors=errors or self.errors)
- def encode(self, text: str) -> List[int]:
- return self.convert_tokens_to_ids(self.tokenize(text))
- def count_tokens(self, text: str) -> int:
- return len(self.tokenize(text))
- def truncate(self, text: str, max_token: int, start_token: int = 0) -> str:
- token_list = self.tokenize(text)
- token_list = token_list[start_token:min(len(token_list), start_token + max_token)]
- return self.convert_tokens_to_string(token_list)
- tokenizer = QWenTokenizer(Path(__file__).resolve().parent / 'qwen.tiktoken')
- def count_tokens(text: str) -> int:
- return tokenizer.count_tokens(text)
|