visualization.py

import logging
import os
import re

import torch

from config import get_model, get_react_parser
from utils.data_utils import load_jsonl, save_jsonl

torch.manual_seed(1234)

# Chinese judging prompt: asks Qwen-VL whether the image matches the [Question],
# replying "right" if they are consistent and "wrong" otherwise.
EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
[问题]:{query}
"""

EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
[Question]: {query}
"""

# Filled in by eval_visualization_acc() with the accuracy (in %) of each subset.
visualization_code_correctness = {
    'visualization-hard': None,
    'visualization-easy': None,
}
def qwen_vl_inference(qwen_vl, imgs=[], prompt=''):
    # Build a multimodal query (images followed by the judging prompt) and ask Qwen-VL-Chat.
    inputs = []
    for img in imgs:
        inputs.append({'image': img})
    inputs.append({'text': prompt})

    logging.info('Eval'.center(60, '-'))
    logging.info(inputs)
    query = qwen_vl.tokenizer.from_list_format(inputs)
    response, history = qwen_vl.model.chat(qwen_vl.tokenizer, query=query, history=None)
    logging.info(response)
    logging.info('=' * 60)
    return response
def extract_images(text):
    # Collect the image paths referenced as markdown figures (![fig-*](path)) in the
    # generation, keeping only those that actually exist on disk.
    regex = re.compile(r'!\[fig-(.+)\]\((.+)\)')
    results = re.findall(regex, text)
    images = []
    for res in results:
        assert len(res) == 2
        if os.path.exists(res[1]):
            images.append(res[1])
    return images
def check_images_observation(text, images, model_name):
    # For each image, inspect the ReAct observation segment that precedes it in the
    # generation and look for signs of a failed execution.
    start_flag = get_react_parser(model_name).observation
    for image in images:
        logging.info('Image'.center(60, '-'))
        logging.info(image)

        end_idx = text.find(image)
        tmp_text = text[:end_idx + len(image)]
        start_idx = tmp_text.rfind(start_flag)
        check_text = tmp_text[start_idx + len(start_flag):]

        logging.info('Observation'.center(60, '-'))
        logging.info(check_text)

        # As long as there exists one correctly executed observation, we consider it `True`.
        if 'error:' not in check_text and 'Traceback' not in check_text:
            return True
    return False
eval_visual_prompt = {
    'zh': EVAL_VISUAL_PROMPT_ZH,
    'en': EVAL_VISUAL_PROMPT_EN,
}
def eval_visualization_acc(output_fname, model_name):
    # Use Qwen-VL-Chat as the judge of whether the generated plots match the query.
    qwen_vl = get_model('qwen-vl-chat')

    one_action, one_action_right = 0, 0
    zero_action, zero_action_right = 0, 0

    data_list = load_jsonl(output_fname)
    for item in data_list:
        if 'visualization' not in item['tags']:
            continue

        item['vis_acc'] = False
        # A query containing '<|im_end|>' belongs to the Visualization-Easy (one_action)
        # subset; otherwise the sample counts toward Visualization-Hard (zero_action).
        if '<|im_end|>' in item['query']:
            one_action += 1
            prompt = item['query'].split('<|im_end|>')[0]
        else:
            zero_action += 1
            prompt = item['query']

        images = extract_images(item['gen'])

        # Only judge samples whose code actually produced images without execution errors.
        if images and check_images_observation(item['gen'], images, model_name):
            input_prompt = eval_visual_prompt[item.get('lang', 'en')]
            format_prompt = input_prompt.format(query=prompt)
            output = qwen_vl_inference(qwen_vl, images, format_prompt)
            if 'right' in output.lower():
                item['vis_acc'] = True
                if '<|im_end|>' in item['query']:
                    one_action_right += 1
                else:
                    zero_action_right += 1

    logging.info('*' * 60)
    logging.info('{:^60}'.format('Visualization Acc.'))
    logging.info('*' * 60)
    logging.info('Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'.format(
        zero_action, zero_action_right, zero_action_right / zero_action * 100))
    logging.info('Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'.format(
        one_action, one_action_right, one_action_right / one_action * 100))
    logging.info('all count={}, all right={}, all acc={:.2f}'.format(
        zero_action + one_action, zero_action_right + one_action_right,
        (zero_action_right + one_action_right) / (zero_action + one_action) * 100))

    visualization_code_correctness['visualization-hard'] = zero_action_right / zero_action * 100
    visualization_code_correctness['visualization-easy'] = one_action_right / one_action * 100

    # Dump the failing visualization samples next to the input file for error analysis.
    error_data_list = [item for item in data_list if 'visualization' in item['tags'] and not item['vis_acc']]
    error_data_output_fname = os.path.splitext(output_fname)[0] + '_vis_error.jsonl'
    save_jsonl(error_data_list, error_data_output_fname)

    return visualization_code_correctness
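

# Minimal usage sketch. The JSONL path and model name below are hypothetical examples;
# the file is expected to hold records with `query`, `gen`, `tags` and, optionally,
# `lang` fields produced by the benchmark's inference step.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    acc = eval_visualization_acc('eval_output/res.jsonl', 'qwen-14b-chat')  # hypothetical inputs
    print(acc)  # {'visualization-hard': <acc %>, 'visualization-easy': <acc %>}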