schema.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from typing import List, Literal, Optional, Tuple, Union
  2. from pydantic import BaseModel, field_validator, model_validator
  3. DEFAULT_SYSTEM_MESSAGE = 'You are a helpful assistant.'
  4. ROLE = 'role'
  5. CONTENT = 'content'
  6. NAME = 'name'
  7. SYSTEM = 'system'
  8. USER = 'user'
  9. ASSISTANT = 'assistant'
  10. FUNCTION = 'function'
  11. FILE = 'file'
  12. IMAGE = 'image'
  13. class BaseModelCompatibleDict(BaseModel):
  14. def __getitem__(self, item):
  15. return getattr(self, item)
  16. def __setitem__(self, key, value):
  17. setattr(self, key, value)
  18. def model_dump(self, **kwargs):
  19. return super().model_dump(exclude_none=True, **kwargs)
  20. def model_dump_json(self, **kwargs):
  21. return super().model_dump_json(exclude_none=True, **kwargs)
  22. def get(self, key, default=None):
  23. try:
  24. value = getattr(self, key)
  25. if value:
  26. return value
  27. else:
  28. return default
  29. except AttributeError:
  30. return default
  31. def __str__(self):
  32. return f'{self.model_dump()}'
  33. class FunctionCall(BaseModelCompatibleDict):
  34. name: str
  35. arguments: str
  36. def __init__(self, name: str, arguments: str):
  37. super().__init__(name=name, arguments=arguments)
  38. def __repr__(self):
  39. return f'FunctionCall({self.model_dump()})'
  40. class ContentItem(BaseModelCompatibleDict):
  41. text: Optional[str] = None
  42. image: Optional[str] = None
  43. file: Optional[str] = None
  44. def __init__(self, text: Optional[str] = None, image: Optional[str] = None, file: Optional[str] = None):
  45. super().__init__(text=text, image=image, file=file)
  46. @model_validator(mode='after')
  47. def check_exclusivity(self):
  48. provided_fields = 0
  49. if self.text is not None:
  50. provided_fields += 1
  51. if self.image:
  52. provided_fields += 1
  53. if self.file:
  54. provided_fields += 1
  55. if provided_fields != 1:
  56. raise ValueError("Exactly one of 'text', 'image', or 'file' must be provided.")
  57. return self
  58. def __repr__(self):
  59. return f'ContentItem({self.model_dump()})'
  60. def get_type_and_value(self) -> Tuple[Literal['text', 'image', 'file'], str]:
  61. (t, v), = self.model_dump().items()
  62. assert t in ('text', 'image', 'file')
  63. return t, v
  64. @property
  65. def type(self) -> Literal['text', 'image', 'file']:
  66. t, v = self.get_type_and_value()
  67. return t
  68. @property
  69. def value(self) -> str:
  70. t, v = self.get_type_and_value()
  71. return v
  72. class Message(BaseModelCompatibleDict):
  73. role: str
  74. content: Union[str, List[ContentItem]]
  75. name: Optional[str] = None
  76. function_call: Optional[FunctionCall] = None
  77. def __init__(self,
  78. role: str,
  79. content: Optional[Union[str, List[ContentItem]]],
  80. name: Optional[str] = None,
  81. function_call: Optional[FunctionCall] = None,
  82. **kwargs):
  83. if content is None:
  84. content = ''
  85. super().__init__(role=role, content=content, name=name, function_call=function_call)
  86. def __repr__(self):
  87. return f'Message({self.model_dump()})'
  88. @field_validator('role')
  89. def role_checker(cls, value: str) -> str:
  90. if value not in [USER, ASSISTANT, SYSTEM, FUNCTION]:
  91. raise ValueError(f'{value} must be one of {",".join([USER, ASSISTANT, SYSTEM, FUNCTION])}')
  92. return value