role.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # -*- coding: utf-8 -*-
  2. from .__load__ import *
  3. from langchain_core.prompts import ChatPromptTemplate
  4. from langchain_core.output_parsers import StrOutputParser
  5. from langchain_core.runnables import RunnablePassthrough
  6. class Role(object):
  7. def init(self, site_id, role_id, uid):
  8. self.info = Demeter.service('common').one('role', id=role_id)
  9. if self.info:
  10. self.model = Demeter.service('common').one('lang_model', id=self.info['lang_model_id'])
  11. self.db = None
  12. self.piece = None
  13. #self.memory()
  14. # 知识库挂载
  15. data = Demeter.service('data').init(site_id)
  16. context = data.load('similarity', {'k':5, 'fetch_k':50, 'filter': {'role_id': role_id, 'uid' : uid}})
  17. #sample = data.load('similarity', {'k':5, 'fetch_k':50, 'filter': {'role_id': role_id, 'uid': 'sample'}})
  18. print(context)
  19. self.piece = {"context": context | self.format_docs, "question": RunnablePassthrough()}
  20. return self
  21. # 写入记忆
  22. def write(self, memory):
  23. pass
  24. # 挂载工具
  25. def tool(self, tool):
  26. pass
  27. def set(self, prompts):
  28. chain = ChatPromptTemplate.from_template(prompts)
  29. if not self.piece:
  30. self.piece = chain
  31. else:
  32. self.piece = self.piece | chain
  33. return self
  34. def out(self, query, type = []):
  35. if self.info:
  36. #self.info['persona'] = '你是一个精美时尚杂志社的编辑,根据以下上下文来回答这个问题{context}'
  37. template = """你是一个精美时尚杂志社的编辑,根据以下上下文来回答这个问题:
  38. {context}
  39. Question: {question},请用中文输出答案。
  40. """
  41. template = """你是一位专业医生。以下是病人的病例内容,请根据医学规范生成详细分析报告。
  42. 病例内容:
  43. {context}
  44. 请根据上面提供的病例内容生成报告。根据病人的核心关注需求提供解决方案。
  45. 报告要求:
  46. 1. 核心健康问题汇总
  47. 2. 潜在风险与关联性分析
  48. 3. 综合健康建议
  49. 4. 紧急情况预警
  50. 5. 解决方案
  51. 请以word格式输出,我好直接生成word。
  52. """
  53. self.set(template)
  54. self.model = Demeter.service(self.model['channel'], 'llm').load(model='deepseek-r1', streaming=True)
  55. full_report = ""
  56. chain = (self.piece | self.model | StrOutputParser())
  57. for chunk in chain.stream(query):
  58. print(chunk, end="")
  59. full_report += chunk
  60. #self.save_docx(full_report)
  61. def format_docs(self, docs):
  62. return "\n\n".join([d.page_content for d in docs])
  63. def save_docx(self, content):
  64. patient_id = self.info.get('uid', 'unknown') # 或者 role_id
  65. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  66. filename = f"report_{patient_id}_{timestamp}.docx"
  67. output_dir = "reports"
  68. os.makedirs(output_dir, exist_ok=True)
  69. filepath = os.path.join(output_dir, filename)
  70. doc = Document()
  71. doc.add_heading('诊断报告', 0)
  72. doc.add_paragraph(report_text)
  73. doc.save(filepath)
  74. print(f"\n\n📝 报告已保存为:{filepath}")
  75. # 生成角色
  76. def create(self, site_id, uid, name, persona, lang_model_id, data, tool):
  77. db = Demeter.db('role')
  78. db.site_id = site_id
  79. db.create_uid = create_uid
  80. db.owner_uid = owner_uid
  81. db.persona = persona
  82. db.lang_model_id = lang_model_id
  83. id = db.insert()
  84. if len(data) > 0:
  85. for key, value in enumerate(data):
  86. pass