""" ===代码审计工程师=== 用于分析具体的源代码,包括数据流、控制流等 """ import json import re import uuid import xml.etree.ElementTree as ET from langchain_core.messages import SystemMessage from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableWithMessageHistory from langchain_openai import ChatOpenAI from langchain_community.chat_message_histories import ChatMessageHistory from agents.CAE.prompt import CAE_SYSTEM_PROMPT, CAE_HUMAN_PROMPT from logger import Logger class CAE: def __init__(self, base_url, api_key, model, process_output_callback): # LLM配置 self.llm = ChatOpenAI(base_url=base_url, api_key=api_key, model=model) self.session_id = uuid.uuid4().hex # 内存记忆 self.max_history_length = 10 self.history = ChatMessageHistory() # 提示词配置 self.system_prompt = CAE_SYSTEM_PROMPT self.human_prompt = CAE_HUMAN_PROMPT # 日志器配置 self.log = Logger(name='CAE', callback=process_output_callback) def audit(self, project_structure, project_module_division, result_output_callback, event): self.log.info('CAE开始审计项目代码') # 提示词模板 self.llm_tmpl = ChatPromptTemplate.from_messages([ SystemMessage(content=self.system_prompt), MessagesPlaceholder(variable_name='history'), HumanMessagePromptTemplate.from_template(template=self.human_prompt), ]) # 调用链配置 self.raw_chain = self.llm_tmpl | self.llm self.llm_chain = RunnableWithMessageHistory( self.raw_chain, lambda session_id: self.history, input_messages_key='content', history_messages_key='history', ) # 进入审计流程 input_content = 'continue' while True: if event.is_set(): return # 剔除更早的对话 while len(self.history.messages) > self.max_history_length: self.history.messages.pop(0) try: # 获取当前输出 input_dict = { 'content': input_content, 'history': self.history.messages, } config_dict = { 'configurable': {'session_id': self.session_id} } result = self.llm_chain.invoke(input_dict, config_dict) if event.is_set(): return # 解析动作指令 if xml_match := re.search(r'.*?', result.content, re.DOTALL): try: xml_content = xml_match.group(0) xml_content = re.sub( r'()(.*?)()', r'\1\3', xml_content, flags=re.DOTALL ) root = ET.fromstring(xml_content) action = root.find('action').text content = root.find('content').text if content and content.startswith(''): content = content[9:-3] except Exception as e: self.log.error(f'CAE动作指令不合法:尝试纠正') input_content = 'ILLEGAL OUTPUT' continue # 执行动作 try: if action == 'QUERY STRUCTURE': self.log.info('CAE请求查询项目结构') input_content = project_structure continue elif action == 'MODULE DIVISION': self.log.info('CAE请求查询项目模块') input_content = project_module_division continue elif action == 'QUERY SOURCE': self.log.info(f'CAE请求查询源代码:{content}') try: input_content = open(content, 'r', encoding='utf-8').read() except Exception as e: input_content = str(e) continue elif action == 'OUTPUT RESULT': self.log.warning('CAE输出代码审计结果') dict_content = eval(content) json_content = json.loads(json.dumps(dict_content)) output_content = f'漏洞类型:{json_content["漏洞类型"]}\n漏洞文件:{json_content["漏洞文件"]}\n相关代码:\n{json_content["相关代码"]}\n修复建议:\n{json_content["修复建议"]}\n' result_output_callback(output_content) input_content = 'continue' continue elif action == 'FINISH TASK': self.log.info('CAE完成项目代码审计') return else: self.log.error(f'CAE动作指令未定义:{action}') return except Exception as e: self.log.error(e) continue except Exception as e: self.log.error(e) continue