Refactor code to deprecate text embedding

This commit is contained in:
2025-02-11 02:34:55 +08:00
parent 5467f72bd7
commit 32d82c99ec
17 changed files with 414 additions and 306 deletions

View File

@@ -1,13 +1,18 @@
# MollyAudit ![](https://socialify.git.ci/yv1ing/MollyAudit/image?language=1&owner=1&name=1&stargazers=1&theme=Light)
An automated code auditing tool powered by langchain. # What's this?
Tool interface: This is an automatic code auditing tool based on LangChain and driven by LLMs.
![](assets/img-01.png) **Basic idea:**
Result Export: I designed two agents, one called Code Audit Engineer (CAE) and another called Code Software Architect (CSA). CSA is responsible for functionally dividing the project structure to be audited and then passing it to CAE as a reference.
![](assets/img-02.png) After CAE collects the project structure, it independently determines the actions of the audit code and outputs the action instructions in a specific format. The local program parses the instructions and sends the relevant results, thus forming an automatic workflow.
I welcome your suggestions for interesting tools :smile: At present, this tool is still in an immature stage and there are still many problems. If you have better suggestions, please contact me!
# Tool interface
![](assets/img-01.png)

143
agents/CAE/__init__.py Normal file
View File

@@ -0,0 +1,143 @@
"""
===代码审计工程师===
用于分析具体的源代码,包括数据流、控制流等
"""
import json
import re
import uuid
import xml.etree.ElementTree as ET
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI
from langchain_community.chat_message_histories import ChatMessageHistory
from agents.CAE.prompt import CAE_SYSTEM_PROMPT, CAE_HUMAN_PROMPT
from logger import Logger
class CAE:
def __init__(self, base_url, api_key, model, process_output_callback):
# LLM配置
self.llm = ChatOpenAI(base_url=base_url, api_key=api_key, model=model)
self.session_id = uuid.uuid4().hex
# 内存记忆
self.max_history_length = 10
self.history = ChatMessageHistory()
# 提示词配置
self.system_prompt = CAE_SYSTEM_PROMPT
self.human_prompt = CAE_HUMAN_PROMPT
# 日志器配置
self.log = Logger(name='CAE', callback=process_output_callback)
def audit(self, project_structure, project_module_division, result_output_callback, event):
self.log.info('CAE开始审计项目代码')
# 提示词模板
self.llm_tmpl = ChatPromptTemplate.from_messages([
SystemMessage(content=self.system_prompt),
MessagesPlaceholder(variable_name='history'),
HumanMessagePromptTemplate.from_template(template=self.human_prompt),
])
# 调用链配置
self.raw_chain = self.llm_tmpl | self.llm
self.llm_chain = RunnableWithMessageHistory(
self.raw_chain,
lambda session_id: self.history,
input_messages_key='content',
history_messages_key='history',
)
# 进入审计流程
input_content = 'continue'
while True:
if event.is_set():
return
# 剔除更早的对话
while len(self.history.messages) > self.max_history_length:
self.history.messages.pop(0)
try:
# 获取当前输出
input_dict = {
'content': input_content,
'history': self.history.messages,
}
config_dict = {
'configurable': {'session_id': self.session_id}
}
result = self.llm_chain.invoke(input_dict, config_dict)
if event.is_set():
return
# 解析动作指令
if xml_match := re.search(r'<root>.*?</root>', result.content, re.DOTALL):
try:
xml_content = xml_match.group(0)
xml_content = re.sub(
r'(<content>)(.*?)(</content>)',
r'\1<![CDATA[\2]]>\3',
xml_content,
flags=re.DOTALL
)
root = ET.fromstring(xml_content)
action = root.find('action').text
content = root.find('content').text
if content and content.startswith('<![CDATA[') and content.endswith(']]>'):
content = content[9:-3]
except Exception as e:
self.log.error(f'CAE动作指令不合法尝试纠正')
input_content = 'ILLEGAL OUTPUT'
continue
# 执行动作
try:
if action == 'QUERY STRUCTURE':
self.log.info('CAE请求查询项目结构')
input_content = project_structure
continue
elif action == 'MODULE DIVISION':
self.log.info('CAE请求查询项目模块')
input_content = project_module_division
continue
elif action == 'QUERY SOURCE':
self.log.info(f'CAE请求查询源代码{content}')
try:
input_content = open(content, 'r', encoding='utf-8').read()
except Exception as e:
input_content = str(e)
continue
elif action == 'OUTPUT RESULT':
self.log.warning('CAE输出代码审计结果')
dict_content = eval(content)
json_content = json.loads(json.dumps(dict_content))
output_content = f'漏洞类型:{json_content["漏洞类型"]}\n漏洞文件:{json_content["漏洞文件"]}\n相关代码:\n{json_content["相关代码"]}\n修复建议:\n{json_content["修复建议"]}\n'
result_output_callback(output_content)
input_content = 'continue'
continue
elif action == 'FINISH TASK':
self.log.info('CAE完成项目代码审计')
return
else:
self.log.error(f'CAE动作指令未定义{action}')
return
except Exception as e:
self.log.error(e)
continue
except Exception as e:
self.log.error(e)
continue

View File

@@ -1,4 +1,4 @@
SYSTEM_PROMPT = """ CAE_SYSTEM_PROMPT = """
You are a professional code audit security expert, responsible for helping users audit possible vulnerabilities and security issues in source code. You are a professional code audit security expert, responsible for helping users audit possible vulnerabilities and security issues in source code.
You will perform code audits according to the following process: You will perform code audits according to the following process:
1. Query project structure 1. Query project structure
@@ -8,29 +8,40 @@ You input the action command in the following format, and the user will send you
<content></content> <content></content>
</root> </root>
2. Query the source code 2. Query module division
You input the action command in the following format, and the user will send you the project module division:
<root>
<action>MODULE DIVISION</action>
<content></content>
</root>
3. Query the source code
You input the action command in the following format, and the user will send you the source code you need below: You input the action command in the following format, and the user will send you the source code you need below:
<root> <root>
<action>QUERY SOURCE</action> <action>QUERY SOURCE</action>
<content>the absolute path of the file you want to query</content> <content>the absolute path of the file you want to query</content>
</root> </root>
3. Output code audit results 4. Output code audit results
You input the code audit results in the following format, and the user will send you "ok", then you can proceed to the next step of the audit: You input the code audit results in the following format, and the user will send you "continue", then you can proceed to the next step of the audit:
<root> <root>
<action>OUTPUT RESULT</action> <action>OUTPUT RESULT</action>
<content>the audit results you want to output</content> <content>the audit results you want to output</content>
</root> </root>
4. Finish audit task 5. Finish audit task
When you are sure that all source code files have been audited, you can output the action instructions to end the task in the following format: When you are sure that all source code files have been audited, you can output the action instructions to end the task in the following format:
<root> <root>
<action>FINISH TASK</action> <action>FINISH TASK</action>
<content></content> <content></content>
</root> </root>
All your output can only be one of the 4 actions mentioned above. Any other form of output is strictly prohibited. Emphasis:
1. The part wrapped in square brackets [] is what you need to fill in according to the actual situation, do not use square brackets when outputting;
2. All your output can only be one of the 5 actions mentioned above. Any other form of output is strictly prohibited;
3. Only output audit results with vulnerabilities, and prohibit output without vulnerabilities!
4. During the audit process, you need to divide the information according to the provided modules and carefully analyze the control flow and data flow of the program. In this process, you can query the contents of multiple files. Remember to give the absolute path of the file to be queried in the format.
5. The audit task can be completed only after all source code files have been audited;
Some additional information, which are some specifications when you perform actions: Some additional information, which are some specifications when you perform actions:
1. The project structure format sent to you is as follows. You need to construct the complete absolute path of the file you want to query based on these hierarchical relationships: 1. The project structure format sent to you is as follows. You need to construct the complete absolute path of the file you want to query based on these hierarchical relationships:
@@ -42,22 +53,32 @@ Some additional information, which are some specifications when you perform acti
- dir_3/ - dir_3/
- 3.php - 3.php
2. When you output the code audit results, you must use Chinese output and follow the following format(Python dict): 2. The project module division format provided by the user is as follows, you can use this as the basis for preliminary code audit:
{'漏洞类型': 'SQL Injection', '漏洞文件': 'main.java', '相关代码': '```java\nString id=request.getParameter("id");\nres = st.executeQuery("SELECT* FROM\"IWEBSEC\".\"user\" WHERE \"id\"="+id);\n```', '修复建议': 'your suggestions...'} HelloWorld Functional division
1 Configuration
- Package name: com.best.hello.config
- Main function: Web application configuration, including MVC and login interceptor.
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/best/hello/config
2 Controller
- Package name: com.best.hello.controller
- Main function: Demonstrating various common web security vulnerabilities through different controllers.
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/best/hello/controller
Most important: Only output audit results with vulnerabilities, and prohibit output without vulnerabilities! 3. When you output the code audit results, you must use Chinese output and follow the following format(Python dict):
{'漏洞类型': 'SQL Injection', '漏洞文件': 'main.java', '相关代码': '```java\nString id=request.getParameter("id");\nres = st.executeQuery("SELECT* FROM\"IWEBSEC\".\"user\" WHERE \"id\"="+id);\n```', '修复建议': 'your suggestions...'}
Some Mandatory regulations: Some Mandatory regulations:
1. Output Format: 1. Output Format:
a. Strictly use the predefined XML tag structure a. Strictly use the predefined XML tag structure
b. Any Markdown symbols are not allowed b. Any Markdown symbols are not allowed
c. No line breaks in the content field c. No line breaks in the content field
d. Do not use quotation marks around the output
2. Language Standards: 2. Language Standards:
a. Technical terms are kept in their original English a. Technical terms are kept in their original English
b. Vulnerability descriptions must be in Chinese b. Vulnerability descriptions must be in Chinese
3. Interaction restrictions: 3. Interaction restrictions:
a. Any content outside the output process is prohibited a. Any content outside the output process is prohibited
b. Autonomously advance the audit process when receiving "nothing" or "ok" b. Autonomously advance the audit process when receiving "continue", such as QUERY SOURCE
c. Vulnerabilities must be output immediately c. Vulnerabilities must be output immediately
4. Error handling: 4. Error handling:
a. When receiving the "ILLEGAL OUTPUT" prompt, terminate the current output immediately and recheck the format specification before continuing a. When receiving the "ILLEGAL OUTPUT" prompt, terminate the current output immediately and recheck the format specification before continuing
@@ -68,3 +89,7 @@ Some Mandatory regulations:
d. For vulnerabilities that may span files, the audit can only begin after the relevant files have been queried as needed d. For vulnerabilities that may span files, the audit can only begin after the relevant files have been queried as needed
e. Only output audit results with vulnerabilities, and prohibit output without vulnerabilities e. Only output audit results with vulnerabilities, and prohibit output without vulnerabilities
""" """
CAE_HUMAN_PROMPT = """
{content}
"""

48
agents/CSA/__init__.py Normal file
View File

@@ -0,0 +1,48 @@
"""
===软件架构师===
用于分析项目的整体框架,抽取出清晰的项目结构和功能划分
"""
from langchain_core.messages import SystemMessage
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_community.callbacks.manager import get_openai_callback
from agents.CSA.prompt import CSA_SYSTEM_PROMPT, CSA_HUMAN_PROMPT
from logger import Logger
class CSA:
def __init__(self, base_url, api_key, model, process_output_callback):
# LLM配置
self.llm = ChatOpenAI(base_url=base_url, api_key=api_key, model=model)
# 提示词配置
self.system_prompt = CSA_SYSTEM_PROMPT
self.human_prompt = CSA_HUMAN_PROMPT
# 日志器配置
self.log = Logger(name='CSA', callback=process_output_callback)
def analyse(self, project_structure):
self.log.info('CSA开始分析项目模块')
# 提示词模板
self.llm_tmpl = ChatPromptTemplate.from_messages([
SystemMessage(content=self.system_prompt),
HumanMessagePromptTemplate.from_template(template=self.human_prompt),
])
# 调用链配置
self.llm_chain = self.llm_tmpl | self.llm
# 获取分析结果
with get_openai_callback() as cb:
result = self.llm_chain.invoke({'project_structure': project_structure})
# TODO: 接入token用量统计
# print(f"请求消耗的输入 token 数: {cb.prompt_tokens}")
# print(f"请求消耗的输出 token 数: {cb.completion_tokens}")
# print(f"请求总共消耗的 token 数: {cb.total_tokens}")
self.log.info('CSA完成分析项目模块')
return result.content

32
agents/CSA/prompt.py Normal file
View File

@@ -0,0 +1,32 @@
CSA_SYSTEM_PROMPT = """
You are a senior software architect, your responsibilities are:
The user provides you with the directory structure of the project, you need to analyze the project, summarize the project function division, and output the results in the following format (Markdown):
[Project name] Functional division
[Module number] [Module name]
- Package name: [Package name]
- Main function: [Main function]
- Absolute file path: [Absolute file path]
Emphasis:
1. The part wrapped in square brackets [] is what you need to fill in according to the actual situation, do not use square brackets when outputting;
2. One package (directory) uses one number;
3. The output absolute path refers to the absolute path of the source code file. All source code files in the same directory must be output;
For example:
HelloWorld Functional division
1 Configuration
- Package name: com.example.hello.config
- Main function: Web application configuration, including MVC and login interceptor.
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/example/hello/config.java
2 Controller
- Package name: com.example.hello.controller
- Main function: Demonstrating various common web security vulnerabilities through different controllers.
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/example/hello/controller.java
"""
CSA_HUMAN_PROMPT = """
The project directory structure provided by the user is as follows:
{project_structure}
Please start the analysis and output according to the format.
"""

0
agents/__init__.py Normal file
View File

View File

@@ -1,6 +1,9 @@
import os import os
import warnings import warnings
from audit import Audit import logger
from agents.CAE import CAE
from agents.CSA import CSA
from utils import build_directory_tree
warnings.simplefilter('ignore', FutureWarning) warnings.simplefilter('ignore', FutureWarning)
@@ -11,8 +14,8 @@ config_file_path = os.path.join(home_dir, config_file_name)
GLOBAL_CONFIG = { GLOBAL_CONFIG = {
"base_url": "https://openai.com/v1", "base_url": "https://openai.com/v1",
"api_key": "", "api_key": "",
"reasoning_model": "o3-mini-all", "csa_model": "gpt-4o",
"embedding_model": "text-embedding-3-small" "cae_model": "o3-mini-all",
} }
@@ -32,7 +35,7 @@ def load_config():
file.write(f"{key}={value}\n") file.write(f"{key}={value}\n")
def update_config(key, value): def real_update_config(key, value):
global GLOBAL_CONFIG global GLOBAL_CONFIG
GLOBAL_CONFIG[key] = value GLOBAL_CONFIG[key] = value
@@ -41,8 +44,18 @@ def update_config(key, value):
file.write(f"{k}={v}\n") file.write(f"{k}={v}\n")
def audit_code(base_url, api_key, src_root, language, reasoning_model, embedding_model, process_output_callback, def audit_code(base_url, api_key, src_root, language, cae_model, csa_model, process_output_callback,
result_output_callback, event): result_output_callback, event):
audit = Audit(base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback) log = logger.Logger('app', process_output_callback)
audit.build_directory_tree(src_root, language)
audit.audit(event) csa = CSA(base_url=base_url, api_key=api_key, model=csa_model, process_output_callback=process_output_callback)
cae = CAE(base_url=base_url, api_key=api_key, model=cae_model, process_output_callback=process_output_callback)
project_structure = build_directory_tree(src_root, language, process_output_callback)
if project_structure is None:
log.error('未找到源代码文件')
return
project_module_division = csa.analyse(project_structure=project_structure)
cae.audit(project_structure=project_structure, project_module_division=project_module_division, event=event,
result_output_callback=result_output_callback)

View File

@@ -1,4 +1,4 @@
LANGUAGE = { SUPPORT_LANGUAGE = {
'c': ['.c'], 'c': ['.c'],
'cpp': ['.cc', '.cpp'], 'cpp': ['.cc', '.cpp'],
'go': ['.go'], 'go': ['.go'],

View File

@@ -1,9 +1,9 @@
import os import os
import re import re
import datetime
import threading import threading
from threading import Event from threading import Event
from app import audit_code, update_config, GLOBAL_CONFIG from app import audit_code, real_update_config, GLOBAL_CONFIG
from app.utils import get_now_date
from logger import Logger from logger import Logger
from PyQt6.QtGui import QColor, QGuiApplication, QTextCursor from PyQt6.QtGui import QColor, QGuiApplication, QTextCursor
from PyQt6.QtWidgets import ( from PyQt6.QtWidgets import (
@@ -18,7 +18,6 @@ from PyQt6.QtWidgets import (
QComboBox QComboBox
) )
BACKGROUND_COLOR = '#dcdcdc' BACKGROUND_COLOR = '#dcdcdc'
ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
ANSI_COLOR_REGEX = re.compile(r'\x1B\[(?:([0-9]+);)?([0-9]+)m') ANSI_COLOR_REGEX = re.compile(r'\x1B\[(?:([0-9]+);)?([0-9]+)m')
@@ -31,6 +30,12 @@ ANSI_COLOR_MAP = {
} }
def get_now_date():
now = datetime.datetime.now()
formatted = now.strftime("%Y-%m-%d-%H-%M-%S")
return formatted
def convert_ansi_to_rich_text(text): def convert_ansi_to_rich_text(text):
segments = [] segments = []
pos = 0 pos = 0
@@ -57,7 +62,8 @@ def convert_ansi_to_rich_text(text):
class MainWindow(QWidget): class MainWindow(QWidget):
def __init__(self): def __init__(self):
self.event = Event() self.event = Event()
self.log = Logger('ui', callback=self.process_output_callback) self.log = Logger('app', callback=self.process_output_callback)
super().__init__() super().__init__()
self.init_ui() self.init_ui()
@@ -100,14 +106,14 @@ class MainWindow(QWidget):
main_layout.addLayout(config_layout) main_layout.addLayout(config_layout)
model_layout = QHBoxLayout() model_layout = QHBoxLayout()
self.reasoning_model_label = QLabel('推理模型:') self.cae_model_label = QLabel('模块分析模型:')
self.reasoning_model_input = QLineEdit() self.cae_model_input = QLineEdit()
self.embedding_model_label = QLabel('嵌入模型:') self.csa_model_label = QLabel('代码审计模型:')
self.embedding_model_input = QLineEdit() self.csa_model_input = QLineEdit()
model_layout.addWidget(self.reasoning_model_label) model_layout.addWidget(self.cae_model_label)
model_layout.addWidget(self.reasoning_model_input) model_layout.addWidget(self.cae_model_input)
model_layout.addWidget(self.embedding_model_label) model_layout.addWidget(self.csa_model_label)
model_layout.addWidget(self.embedding_model_input) model_layout.addWidget(self.csa_model_input)
main_layout.addLayout(model_layout) main_layout.addLayout(model_layout)
# 按钮部分 # 按钮部分
@@ -156,7 +162,9 @@ class MainWindow(QWidget):
# 导出结果 # 导出结果
export_button_layout = QHBoxLayout() export_button_layout = QHBoxLayout()
link_label = QLabel('联系作者:<a href="https://github.com/yv1ing">Github</a> <a href=mailto:me@yvling.cn>Email</a>') link_label = QLabel(
'联系作者:<a href="https://github.com/yv1ing">Github</a> <a href=mailto:me@yvling.cn>Email</a>'
)
link_label.setOpenExternalLinks(True) link_label.setOpenExternalLinks(True)
export_button_layout.addWidget(link_label) export_button_layout.addWidget(link_label)
export_button_layout.addStretch(1) export_button_layout.addStretch(1)
@@ -169,8 +177,8 @@ class MainWindow(QWidget):
# 加载配置 # 加载配置
self.base_url_input.setText(GLOBAL_CONFIG['base_url']) self.base_url_input.setText(GLOBAL_CONFIG['base_url'])
self.api_key_input.setText(GLOBAL_CONFIG['api_key']) self.api_key_input.setText(GLOBAL_CONFIG['api_key'])
self.reasoning_model_input.setText(GLOBAL_CONFIG['reasoning_model']) self.cae_model_input.setText(GLOBAL_CONFIG['cae_model'])
self.embedding_model_input.setText(GLOBAL_CONFIG['embedding_model']) self.csa_model_input.setText(GLOBAL_CONFIG['csa_model'])
def closeEvent(self, event): def closeEvent(self, event):
self.event.set() self.event.set()
@@ -182,14 +190,13 @@ class MainWindow(QWidget):
def update_config(self): def update_config(self):
base_url = self.base_url_input.text() base_url = self.base_url_input.text()
api_key = self.api_key_input.text() api_key = self.api_key_input.text()
reasoning_model = self.reasoning_model_input.text() cae_model = self.cae_model_input.text()
embedding_model = self.embedding_model_input.text() csa_model = self.csa_model_input.text()
update_config('base_url', base_url)
update_config('api_key', api_key)
update_config('reasoning_model', reasoning_model)
update_config('embedding_model', embedding_model)
real_update_config('base_url', base_url)
real_update_config('api_key', api_key)
real_update_config('cae_model', cae_model)
real_update_config('csa_model', csa_model)
self.log.info('更新配置成功') self.log.info('更新配置成功')
def select_directory(self): def select_directory(self):
@@ -234,8 +241,8 @@ class MainWindow(QWidget):
selected_lang = self.lang_combobox.currentText() selected_lang = self.lang_combobox.currentText()
base_url = self.base_url_input.text() base_url = self.base_url_input.text()
api_key = self.api_key_input.text() api_key = self.api_key_input.text()
reasoning_model = self.reasoning_model_input.text() cae_model = self.cae_model_input.text()
embedding_model = self.embedding_model_input.text() csa_model = self.csa_model_input.text()
if not selected_dir or not base_url or not api_key: if not selected_dir or not base_url or not api_key:
self.log.error('请确保项目目录、接口地址和模型密钥等都已填写') self.log.error('请确保项目目录、接口地址和模型密钥等都已填写')
@@ -250,8 +257,8 @@ class MainWindow(QWidget):
api_key, api_key,
selected_dir, selected_dir,
selected_lang, selected_lang,
reasoning_model, cae_model,
embedding_model, csa_model,
self.process_output_callback, self.process_output_callback,
self.result_output_callback, self.result_output_callback,
self.event self.event

View File

@@ -1,7 +0,0 @@
import datetime
def get_now_date():
now = datetime.datetime.now()
formatted = now.strftime("%Y-%m-%d-%H-%M-%S")
return formatted

Binary file not shown.

Before

Width:  |  Height:  |  Size: 121 KiB

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 60 KiB

View File

@@ -1,233 +0,0 @@
import json
import os
import re
import uuid
import xml.etree.ElementTree as ET
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.messages import SystemMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.vectorstores import FAISS
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_text_splitters import CharacterTextSplitter
from logger import Logger
from audit import callback
from audit.prompt import SYSTEM_PROMPT
from audit.language import LANGUAGE
xml_pattern = r'<root>.*?</root>'
class Audit:
def __init__(self, base_url, api_key, reasoning_model, embedding_model, process_output_callback,
result_output_callback):
self.raw_chain = None
self.directory_tree = None
self.reasoning_model = reasoning_model
self.embedding_model = embedding_model
self.process_output_callback = process_output_callback
self.result_output_callback = result_output_callback
self.chat_history = ChatMessageHistory()
self.session_id = uuid.uuid4().hex
self.response_callback = callback.CustomCallbackHandler()
self.embedding = OpenAIEmbeddings(
base_url=base_url,
api_key=api_key,
model=embedding_model
)
self.llm = ChatOpenAI(
base_url=base_url,
api_key=api_key,
model=reasoning_model,
streaming=True,
callbacks=[self.response_callback]
)
self.log = Logger('audit', callback=self.process_output_callback)
self.splitter = CharacterTextSplitter(
chunk_size=300,
chunk_overlap=0,
separator=". ",
)
self.messages_db = FAISS.from_texts(['nothing'], self.embedding)
self.retriever = self.messages_db.as_retriever()
self.redundant_filter = EmbeddingsRedundantFilter(embeddings=self.embedding)
self.relevant_filter = EmbeddingsFilter(
embeddings=self.embedding,
similarity_threshold=0.76,
)
self.pipeline_compressor = DocumentCompressorPipeline(
transformers=[self.splitter, self.redundant_filter, self.relevant_filter]
)
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=self.pipeline_compressor,
base_retriever=self.retriever,
)
self.prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=SYSTEM_PROMPT),
MessagesPlaceholder(variable_name='messages'),
('human', '{input}'),
])
def audit(self, event):
self.log.info('开始代码审计流程')
self.log.info(f'当前推理模型:{self.reasoning_model}')
self.log.info(f'当前嵌入模型:{self.embedding_model}')
input_content = ''
while True:
if event.is_set():
return
try:
result = self.send_message(input_content)
except Exception as e:
self.log.error(e)
return
if event.is_set():
return
if xml_match := re.search(xml_pattern, result, re.DOTALL):
try:
xml_content = xml_match.group(0)
xml_content = re.sub(
r'(<content>)(.*?)(</content>)',
r'\1<![CDATA[\2]]>\3',
xml_content,
flags=re.DOTALL
)
root = ET.fromstring(xml_content)
action = root.find('action').text
content = root.find('content').text
if content and content.startswith('<![CDATA[') and content.endswith(']]>'):
content = content[9:-3]
except Exception as e:
print(result)
print(e)
self.log.error(f'动作指令不合法,尝试纠正')
input_content = 'ILLEGAL OUTPUT'
continue
try:
if action == 'QUERY STRUCTURE':
self.log.info('请求查询项目结构')
input_content = self.print_tree(self.directory_tree)
self.store_messages_in_faiss(input_content)
continue
elif action == 'QUERY SOURCE':
self.log.info(f'请求查询源代码:{content}')
input_content = open(content, 'r', encoding='utf-8').read()
self.store_messages_in_faiss(input_content)
continue
elif action == 'OUTPUT RESULT':
self.log.warning('输出代码审计结果')
dict_content = eval(content)
json_content = json.loads(json.dumps(dict_content))
output_content = f'漏洞类型:{json_content["漏洞类型"]}\n漏洞文件:{json_content["漏洞文件"]}\n相关代码:\n{json_content["相关代码"]}\n修复建议:\n{json_content["修复建议"]}\n'
self.result_output_callback(output_content)
self.store_messages_in_faiss(output_content)
input_content = 'ok'
continue
elif action == 'FINISH TASK':
self.log.info('代码审计任务已完成')
return
else:
self.log.error(f'动作指令未定义:{action}')
return
except Exception as e:
self.log.error(e)
continue
def send_message(self, input_content):
self.response_callback.temp_content = ''
compressed_context = self.compression_retriever.invoke(input_content)
if input_content == '':
input_content = 'nothing'
input_dict = {
'input': input_content,
'context': compressed_context,
}
config_dict = {
'configurable': {'session_id': self.session_id}
}
self.raw_chain = self.prompt | self.llm
chain_with_message_history = RunnableWithMessageHistory(
self.raw_chain,
lambda session_id: self.chat_history,
input_messages_key='input',
history_messages_key='messages',
)
for _ in chain_with_message_history.stream(input_dict, config_dict):
pass
return self.response_callback.temp_content
def store_messages_in_faiss(self, message):
text_embedding = self.embedding.embed_query(message)
doc_id = str(uuid.uuid4())
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
def build_directory_tree(self, path, lang):
if lang in LANGUAGE:
suffixes = LANGUAGE[lang]
else:
self.log.error(f'不支持的语言:{lang}')
return
absolute_path = os.path.abspath(path).replace('\\', '/')
tree = {absolute_path: {}}
for root, _, files in os.walk(absolute_path):
relative_path = os.path.relpath(root, absolute_path)
current_node = tree[absolute_path]
if relative_path != '.':
parts = relative_path.split(os.sep)
for part in parts:
if part not in current_node:
current_node[part] = {}
current_node = current_node[part]
for suffix in suffixes:
lang_files = [file for file in files if file.endswith(suffix)]
if lang_files:
if 'files' not in current_node:
current_node['files'] = []
current_node['files'].extend(lang_files)
self.print_tree(tree)
self.directory_tree = tree
def format_tree(self, node, level=0):
result = []
indent = ' ' * level
for key, value in node.items():
if key == 'files':
for file in value:
result.append(f"{indent}- {file}")
else:
result.append(f"{indent}- {key}/")
if isinstance(value, dict):
result.extend(self.format_tree(value, level + 1))
return result
def print_tree(self, tree):
formatted_str = ''
formatted = self.format_tree(tree)
for line in formatted:
formatted_str += f"{line}\n"
# print(line)
return formatted_str

View File

@@ -1,13 +0,0 @@
from langchain_core.callbacks import BaseCallbackHandler
class CustomCallbackHandler(BaseCallbackHandler):
def __init__(self):
self.temp_content = ''
def on_llm_new_token(self, token: str, **kwargs) -> None:
self.temp_content += token
def on_llm_end(self, response, **kwargs):
pass

View File

@@ -15,6 +15,7 @@ def resource_path(relative_path):
base_path = os.path.abspath(".") base_path = os.path.abspath(".")
return os.path.join(base_path, relative_path) return os.path.join(base_path, relative_path)
def main(): def main():
try: try:
app.exec() app.exec()

Binary file not shown.

87
utils/__init__.py Normal file
View File

@@ -0,0 +1,87 @@
import os
import logger
from app.constants import SUPPORT_LANGUAGE
def clean_tree(node, suffixes):
sub_keys_to_remove = []
for key, value in node.items():
if key == 'files':
continue
elif isinstance(value, dict):
clean_tree(value, suffixes)
has_valid_content = 'files' in value or any(isinstance(sub_val, dict) and sub_val for sub_val in value.values())
if not has_valid_content:
sub_keys_to_remove.append(key)
for key in sub_keys_to_remove:
del node[key]
if all(key == 'files' for key in node) and ('files' not in node or not node['files']):
node.clear()
def print_tree(tree):
formatted_str = ''
formatted = format_tree(tree)
for line in formatted:
formatted_str += f"{line}\n"
return formatted_str
def format_tree(node, level=0):
result = []
indent = ' ' * level
for key, value in node.items():
if key == 'files':
for file in value:
result.append(f"{indent}- {file}")
else:
result.append(f"{indent}- {key}/")
if isinstance(value, dict):
result.extend(format_tree(value, level + 1))
return result
def build_directory_tree(path, lang, callback):
log = logger.Logger('app', callback=callback)
if lang in SUPPORT_LANGUAGE:
suffixes = SUPPORT_LANGUAGE[lang]
else:
log.error(f'不支持的语言:{lang}')
return
absolute_path = os.path.abspath(path).replace('\\', '/')
src_files = []
tree = {absolute_path: {}}
for root, _, files in os.walk(absolute_path):
relative_path = os.path.relpath(root, absolute_path)
current_node = tree[absolute_path]
if relative_path != '.':
parts = relative_path.split(os.sep)
for part in parts:
if part not in current_node:
current_node[part] = {}
current_node = current_node[part]
for suffix in suffixes:
lang_files = [file for file in files if file.endswith(suffix)]
if lang_files:
if 'files' not in current_node:
current_node['files'] = []
current_node['files'].extend(lang_files)
src_files.extend(lang_files)
if len(src_files) > 0:
clean_tree(tree[absolute_path], suffixes)
return print_tree(tree)
else:
return None