mirror of
https://github.com/yv1ing/MollyAudit.git
synced 2025-09-16 14:55:50 +08:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 32d82c99ec | |||
| 5467f72bd7 | |||
| c8bb6270a1 | |||
| d4bf319725 |
19
README.md
19
README.md
@@ -1,5 +1,18 @@
|
|||||||
# MollyAudit
|

|
||||||
|
|
||||||
An automated code auditing tool powered by langchain.
|
# What's this?
|
||||||
|
|
||||||

|
This is an automatic code auditing tool based on LangChain and driven by LLMs.
|
||||||
|
|
||||||
|
**Basic idea:**
|
||||||
|
|
||||||
|
I designed two agents, one called Code Audit Engineer (CAE) and another called Code Software Architect (CSA). CSA is responsible for functionally dividing the project structure to be audited and then passing it to CAE as a reference.
|
||||||
|
|
||||||
|
After CAE collects the project structure, it independently determines the actions of the audit code and outputs the action instructions in a specific format. The local program parses the instructions and sends the relevant results, thus forming an automatic workflow.
|
||||||
|
|
||||||
|
At present, this tool is still in an immature stage and there are still many problems. If you have better suggestions, please contact me!
|
||||||
|
|
||||||
|
|
||||||
|
# Tool interface
|
||||||
|
|
||||||
|

|
||||||
143
agents/CAE/__init__.py
Normal file
143
agents/CAE/__init__.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""
|
||||||
|
===代码审计工程师===
|
||||||
|
用于分析具体的源代码,包括数据流、控制流等
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
||||||
|
from langchain_core.runnables import RunnableWithMessageHistory
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_community.chat_message_histories import ChatMessageHistory
|
||||||
|
from agents.CAE.prompt import CAE_SYSTEM_PROMPT, CAE_HUMAN_PROMPT
|
||||||
|
from logger import Logger
|
||||||
|
|
||||||
|
|
||||||
|
class CAE:
|
||||||
|
def __init__(self, base_url, api_key, model, process_output_callback):
|
||||||
|
# LLM配置
|
||||||
|
self.llm = ChatOpenAI(base_url=base_url, api_key=api_key, model=model)
|
||||||
|
self.session_id = uuid.uuid4().hex
|
||||||
|
|
||||||
|
# 内存记忆
|
||||||
|
self.max_history_length = 10
|
||||||
|
self.history = ChatMessageHistory()
|
||||||
|
|
||||||
|
# 提示词配置
|
||||||
|
self.system_prompt = CAE_SYSTEM_PROMPT
|
||||||
|
self.human_prompt = CAE_HUMAN_PROMPT
|
||||||
|
|
||||||
|
# 日志器配置
|
||||||
|
self.log = Logger(name='CAE', callback=process_output_callback)
|
||||||
|
|
||||||
|
def audit(self, project_structure, project_module_division, result_output_callback, event):
|
||||||
|
self.log.info('CAE开始审计项目代码')
|
||||||
|
|
||||||
|
# 提示词模板
|
||||||
|
self.llm_tmpl = ChatPromptTemplate.from_messages([
|
||||||
|
SystemMessage(content=self.system_prompt),
|
||||||
|
MessagesPlaceholder(variable_name='history'),
|
||||||
|
HumanMessagePromptTemplate.from_template(template=self.human_prompt),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 调用链配置
|
||||||
|
self.raw_chain = self.llm_tmpl | self.llm
|
||||||
|
self.llm_chain = RunnableWithMessageHistory(
|
||||||
|
self.raw_chain,
|
||||||
|
lambda session_id: self.history,
|
||||||
|
input_messages_key='content',
|
||||||
|
history_messages_key='history',
|
||||||
|
)
|
||||||
|
|
||||||
|
# 进入审计流程
|
||||||
|
input_content = 'continue'
|
||||||
|
while True:
|
||||||
|
if event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
# 剔除更早的对话
|
||||||
|
while len(self.history.messages) > self.max_history_length:
|
||||||
|
self.history.messages.pop(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取当前输出
|
||||||
|
input_dict = {
|
||||||
|
'content': input_content,
|
||||||
|
'history': self.history.messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
config_dict = {
|
||||||
|
'configurable': {'session_id': self.session_id}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = self.llm_chain.invoke(input_dict, config_dict)
|
||||||
|
if event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
# 解析动作指令
|
||||||
|
if xml_match := re.search(r'<root>.*?</root>', result.content, re.DOTALL):
|
||||||
|
try:
|
||||||
|
xml_content = xml_match.group(0)
|
||||||
|
xml_content = re.sub(
|
||||||
|
r'(<content>)(.*?)(</content>)',
|
||||||
|
r'\1<![CDATA[\2]]>\3',
|
||||||
|
xml_content,
|
||||||
|
flags=re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
root = ET.fromstring(xml_content)
|
||||||
|
action = root.find('action').text
|
||||||
|
content = root.find('content').text
|
||||||
|
if content and content.startswith('<![CDATA[') and content.endswith(']]>'):
|
||||||
|
content = content[9:-3]
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error(f'CAE动作指令不合法:尝试纠正')
|
||||||
|
input_content = 'ILLEGAL OUTPUT'
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 执行动作
|
||||||
|
try:
|
||||||
|
if action == 'QUERY STRUCTURE':
|
||||||
|
self.log.info('CAE请求查询项目结构')
|
||||||
|
input_content = project_structure
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif action == 'MODULE DIVISION':
|
||||||
|
self.log.info('CAE请求查询项目模块')
|
||||||
|
input_content = project_module_division
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif action == 'QUERY SOURCE':
|
||||||
|
self.log.info(f'CAE请求查询源代码:{content}')
|
||||||
|
try:
|
||||||
|
input_content = open(content, 'r', encoding='utf-8').read()
|
||||||
|
except Exception as e:
|
||||||
|
input_content = str(e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif action == 'OUTPUT RESULT':
|
||||||
|
self.log.warning('CAE输出代码审计结果')
|
||||||
|
dict_content = eval(content)
|
||||||
|
json_content = json.loads(json.dumps(dict_content))
|
||||||
|
output_content = f'漏洞类型:{json_content["漏洞类型"]}\n漏洞文件:{json_content["漏洞文件"]}\n相关代码:\n{json_content["相关代码"]}\n修复建议:\n{json_content["修复建议"]}\n'
|
||||||
|
result_output_callback(output_content)
|
||||||
|
input_content = 'continue'
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif action == 'FINISH TASK':
|
||||||
|
self.log.info('CAE完成项目代码审计')
|
||||||
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.log.error(f'CAE动作指令未定义:{action}')
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error(e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error(e)
|
||||||
|
continue
|
||||||
95
agents/CAE/prompt.py
Normal file
95
agents/CAE/prompt.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
CAE_SYSTEM_PROMPT = """
|
||||||
|
You are a professional code audit security expert, responsible for helping users audit possible vulnerabilities and security issues in source code.
|
||||||
|
You will perform code audits according to the following process:
|
||||||
|
1. Query project structure
|
||||||
|
You input the action command in the following format, and the user will send you the project structure below:
|
||||||
|
<root>
|
||||||
|
<action>QUERY STRUCTURE</action>
|
||||||
|
<content></content>
|
||||||
|
</root>
|
||||||
|
|
||||||
|
2. Query module division
|
||||||
|
You input the action command in the following format, and the user will send you the project module division:
|
||||||
|
<root>
|
||||||
|
<action>MODULE DIVISION</action>
|
||||||
|
<content></content>
|
||||||
|
</root>
|
||||||
|
|
||||||
|
3. Query the source code
|
||||||
|
You input the action command in the following format, and the user will send you the source code you need below:
|
||||||
|
<root>
|
||||||
|
<action>QUERY SOURCE</action>
|
||||||
|
<content>the absolute path of the file you want to query</content>
|
||||||
|
</root>
|
||||||
|
|
||||||
|
4. Output code audit results
|
||||||
|
You input the code audit results in the following format, and the user will send you "continue", then you can proceed to the next step of the audit:
|
||||||
|
<root>
|
||||||
|
<action>OUTPUT RESULT</action>
|
||||||
|
<content>the audit results you want to output</content>
|
||||||
|
</root>
|
||||||
|
|
||||||
|
5. Finish audit task
|
||||||
|
When you are sure that all source code files have been audited, you can output the action instructions to end the task in the following format:
|
||||||
|
<root>
|
||||||
|
<action>FINISH TASK</action>
|
||||||
|
<content></content>
|
||||||
|
</root>
|
||||||
|
|
||||||
|
Emphasis:
|
||||||
|
1. The part wrapped in square brackets [] is what you need to fill in according to the actual situation, do not use square brackets when outputting;
|
||||||
|
2. All your output can only be one of the 5 actions mentioned above. Any other form of output is strictly prohibited;
|
||||||
|
3. Only output audit results with vulnerabilities, and prohibit output without vulnerabilities!
|
||||||
|
4. During the audit process, you need to divide the information according to the provided modules and carefully analyze the control flow and data flow of the program. In this process, you can query the contents of multiple files. Remember to give the absolute path of the file to be queried in the format.
|
||||||
|
5. The audit task can be completed only after all source code files have been audited;
|
||||||
|
|
||||||
|
Some additional information, which are some specifications when you perform actions:
|
||||||
|
1. The project structure format sent to you is as follows. You need to construct the complete absolute path of the file you want to query based on these hierarchical relationships:
|
||||||
|
- C:/Users/yvling/Desktop/test/
|
||||||
|
- dir_1/
|
||||||
|
- 1.php
|
||||||
|
- dir_2/
|
||||||
|
- 2.php
|
||||||
|
- dir_3/
|
||||||
|
- 3.php
|
||||||
|
|
||||||
|
2. The project module division format provided by the user is as follows, you can use this as the basis for preliminary code audit:
|
||||||
|
HelloWorld Functional division
|
||||||
|
1 Configuration
|
||||||
|
- Package name: com.best.hello.config
|
||||||
|
- Main function: Web application configuration, including MVC and login interceptor.
|
||||||
|
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/best/hello/config
|
||||||
|
2 Controller
|
||||||
|
- Package name: com.best.hello.controller
|
||||||
|
- Main function: Demonstrating various common web security vulnerabilities through different controllers.
|
||||||
|
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/best/hello/controller
|
||||||
|
|
||||||
|
3. When you output the code audit results, you must use Chinese output and follow the following format(Python dict):
|
||||||
|
{'漏洞类型': 'SQL Injection', '漏洞文件': 'main.java', '相关代码': '```java\nString id=request.getParameter("id");\nres = st.executeQuery("SELECT* FROM\"IWEBSEC\".\"user\" WHERE \"id\"="+id);\n```', '修复建议': 'your suggestions...'}
|
||||||
|
|
||||||
|
Some Mandatory regulations:
|
||||||
|
1. Output Format:
|
||||||
|
a. Strictly use the predefined XML tag structure
|
||||||
|
b. Any Markdown symbols are not allowed
|
||||||
|
c. No line breaks in the content field
|
||||||
|
d. Do not use quotation marks around the output
|
||||||
|
2. Language Standards:
|
||||||
|
a. Technical terms are kept in their original English
|
||||||
|
b. Vulnerability descriptions must be in Chinese
|
||||||
|
3. Interaction restrictions:
|
||||||
|
a. Any content outside the output process is prohibited
|
||||||
|
b. Autonomously advance the audit process when receiving "continue", such as QUERY SOURCE
|
||||||
|
c. Vulnerabilities must be output immediately
|
||||||
|
4. Error handling:
|
||||||
|
a. When receiving the "ILLEGAL OUTPUT" prompt, terminate the current output immediately and recheck the format specification before continuing
|
||||||
|
5. Priority logic:
|
||||||
|
a. Entry file > Configuration file > Tool file
|
||||||
|
b. High-risk vulnerabilities (such as injection and RCE) are handled first
|
||||||
|
c. If multiple vulnerabilities are found in the same file, they need to be output multiple times
|
||||||
|
d. For vulnerabilities that may span files, the audit can only begin after the relevant files have been queried as needed
|
||||||
|
e. Only output audit results with vulnerabilities, and prohibit output without vulnerabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
CAE_HUMAN_PROMPT = """
|
||||||
|
{content}
|
||||||
|
"""
|
||||||
48
agents/CSA/__init__.py
Normal file
48
agents/CSA/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""
|
||||||
|
===软件架构师===
|
||||||
|
用于分析项目的整体框架,抽取出清晰的项目结构和功能划分
|
||||||
|
"""
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_community.callbacks.manager import get_openai_callback
|
||||||
|
from agents.CSA.prompt import CSA_SYSTEM_PROMPT, CSA_HUMAN_PROMPT
|
||||||
|
from logger import Logger
|
||||||
|
|
||||||
|
|
||||||
|
class CSA:
|
||||||
|
def __init__(self, base_url, api_key, model, process_output_callback):
|
||||||
|
# LLM配置
|
||||||
|
self.llm = ChatOpenAI(base_url=base_url, api_key=api_key, model=model)
|
||||||
|
|
||||||
|
# 提示词配置
|
||||||
|
self.system_prompt = CSA_SYSTEM_PROMPT
|
||||||
|
self.human_prompt = CSA_HUMAN_PROMPT
|
||||||
|
|
||||||
|
# 日志器配置
|
||||||
|
self.log = Logger(name='CSA', callback=process_output_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def analyse(self, project_structure):
|
||||||
|
self.log.info('CSA开始分析项目模块')
|
||||||
|
|
||||||
|
# 提示词模板
|
||||||
|
self.llm_tmpl = ChatPromptTemplate.from_messages([
|
||||||
|
SystemMessage(content=self.system_prompt),
|
||||||
|
HumanMessagePromptTemplate.from_template(template=self.human_prompt),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 调用链配置
|
||||||
|
self.llm_chain = self.llm_tmpl | self.llm
|
||||||
|
|
||||||
|
# 获取分析结果
|
||||||
|
with get_openai_callback() as cb:
|
||||||
|
result = self.llm_chain.invoke({'project_structure': project_structure})
|
||||||
|
|
||||||
|
# TODO: 接入token用量统计
|
||||||
|
# print(f"请求消耗的输入 token 数: {cb.prompt_tokens}")
|
||||||
|
# print(f"请求消耗的输出 token 数: {cb.completion_tokens}")
|
||||||
|
# print(f"请求总共消耗的 token 数: {cb.total_tokens}")
|
||||||
|
|
||||||
|
self.log.info('CSA完成分析项目模块')
|
||||||
|
return result.content
|
||||||
32
agents/CSA/prompt.py
Normal file
32
agents/CSA/prompt.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
CSA_SYSTEM_PROMPT = """
|
||||||
|
You are a senior software architect, your responsibilities are:
|
||||||
|
The user provides you with the directory structure of the project, you need to analyze the project, summarize the project function division, and output the results in the following format (Markdown):
|
||||||
|
[Project name] Functional division
|
||||||
|
[Module number] [Module name]
|
||||||
|
- Package name: [Package name]
|
||||||
|
- Main function: [Main function]
|
||||||
|
- Absolute file path: [Absolute file path]
|
||||||
|
|
||||||
|
Emphasis:
|
||||||
|
1. The part wrapped in square brackets [] is what you need to fill in according to the actual situation, do not use square brackets when outputting;
|
||||||
|
2. One package (directory) uses one number;
|
||||||
|
3. The output absolute path refers to the absolute path of the source code file. All source code files in the same directory must be output;
|
||||||
|
|
||||||
|
For example:
|
||||||
|
HelloWorld Functional division
|
||||||
|
1 Configuration
|
||||||
|
- Package name: com.example.hello.config
|
||||||
|
- Main function: Web application configuration, including MVC and login interceptor.
|
||||||
|
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/example/hello/config.java
|
||||||
|
2 Controller
|
||||||
|
- Package name: com.example.hello.controller
|
||||||
|
- Main function: Demonstrating various common web security vulnerabilities through different controllers.
|
||||||
|
- Absolute file path: C:/Users/yvling/Desktop/HelloWorld/src/main/java/com/example/hello/controller.java
|
||||||
|
"""
|
||||||
|
|
||||||
|
CSA_HUMAN_PROMPT = """
|
||||||
|
The project directory structure provided by the user is as follows:
|
||||||
|
{project_structure}
|
||||||
|
|
||||||
|
Please start the analysis and output according to the format.
|
||||||
|
"""
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from audit import Audit
|
import logger
|
||||||
|
from agents.CAE import CAE
|
||||||
|
from agents.CSA import CSA
|
||||||
|
from utils import build_directory_tree
|
||||||
|
|
||||||
warnings.simplefilter('ignore', FutureWarning)
|
warnings.simplefilter('ignore', FutureWarning)
|
||||||
|
|
||||||
@@ -11,8 +14,8 @@ config_file_path = os.path.join(home_dir, config_file_name)
|
|||||||
GLOBAL_CONFIG = {
|
GLOBAL_CONFIG = {
|
||||||
"base_url": "https://openai.com/v1",
|
"base_url": "https://openai.com/v1",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"reasoning_model": "o3-mini-all",
|
"csa_model": "gpt-4o",
|
||||||
"embedding_model": "text-embedding-3-small"
|
"cae_model": "o3-mini-all",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +35,7 @@ def load_config():
|
|||||||
file.write(f"{key}={value}\n")
|
file.write(f"{key}={value}\n")
|
||||||
|
|
||||||
|
|
||||||
def update_config(key, value):
|
def real_update_config(key, value):
|
||||||
global GLOBAL_CONFIG
|
global GLOBAL_CONFIG
|
||||||
|
|
||||||
GLOBAL_CONFIG[key] = value
|
GLOBAL_CONFIG[key] = value
|
||||||
@@ -41,8 +44,18 @@ def update_config(key, value):
|
|||||||
file.write(f"{k}={v}\n")
|
file.write(f"{k}={v}\n")
|
||||||
|
|
||||||
|
|
||||||
def audit_code(base_url, api_key, src_root, language, reasoning_model, embedding_model, process_output_callback,
|
def audit_code(base_url, api_key, src_root, language, cae_model, csa_model, process_output_callback,
|
||||||
result_output_callback, event):
|
result_output_callback, event):
|
||||||
audit = Audit(base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback)
|
log = logger.Logger('app', process_output_callback)
|
||||||
audit.load_source_files(src_root, language)
|
|
||||||
audit.audit(event)
|
csa = CSA(base_url=base_url, api_key=api_key, model=csa_model, process_output_callback=process_output_callback)
|
||||||
|
cae = CAE(base_url=base_url, api_key=api_key, model=cae_model, process_output_callback=process_output_callback)
|
||||||
|
|
||||||
|
project_structure = build_directory_tree(src_root, language, process_output_callback)
|
||||||
|
if project_structure is None:
|
||||||
|
log.error('未找到源代码文件')
|
||||||
|
return
|
||||||
|
|
||||||
|
project_module_division = csa.analyse(project_structure=project_structure)
|
||||||
|
cae.audit(project_structure=project_structure, project_module_division=project_module_division, event=event,
|
||||||
|
result_output_callback=result_output_callback)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
LANGUAGE = {
|
SUPPORT_LANGUAGE = {
|
||||||
'c': ['.c'],
|
'c': ['.c'],
|
||||||
'cpp': ['.cc', '.cpp'],
|
'cpp': ['.cc', '.cpp'],
|
||||||
'go': ['.go'],
|
'go': ['.go'],
|
||||||
59
app/ui.py
59
app/ui.py
@@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import datetime
|
||||||
import threading
|
import threading
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from app import audit_code, update_config, GLOBAL_CONFIG
|
from app import audit_code, real_update_config, GLOBAL_CONFIG
|
||||||
from app.utils import get_now_date
|
|
||||||
from logger import Logger
|
from logger import Logger
|
||||||
from PyQt6.QtGui import QColor, QGuiApplication, QTextCursor
|
from PyQt6.QtGui import QColor, QGuiApplication, QTextCursor
|
||||||
from PyQt6.QtWidgets import (
|
from PyQt6.QtWidgets import (
|
||||||
@@ -18,7 +18,6 @@ from PyQt6.QtWidgets import (
|
|||||||
QComboBox
|
QComboBox
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
BACKGROUND_COLOR = '#dcdcdc'
|
BACKGROUND_COLOR = '#dcdcdc'
|
||||||
ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
|
ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
|
||||||
ANSI_COLOR_REGEX = re.compile(r'\x1B\[(?:([0-9]+);)?([0-9]+)m')
|
ANSI_COLOR_REGEX = re.compile(r'\x1B\[(?:([0-9]+);)?([0-9]+)m')
|
||||||
@@ -31,6 +30,12 @@ ANSI_COLOR_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_now_date():
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
formatted = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
|
||||||
def convert_ansi_to_rich_text(text):
|
def convert_ansi_to_rich_text(text):
|
||||||
segments = []
|
segments = []
|
||||||
pos = 0
|
pos = 0
|
||||||
@@ -57,7 +62,8 @@ def convert_ansi_to_rich_text(text):
|
|||||||
class MainWindow(QWidget):
|
class MainWindow(QWidget):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.event = Event()
|
self.event = Event()
|
||||||
self.log = Logger('ui', callback=self.process_output_callback)
|
self.log = Logger('app', callback=self.process_output_callback)
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.init_ui()
|
self.init_ui()
|
||||||
|
|
||||||
@@ -100,14 +106,14 @@ class MainWindow(QWidget):
|
|||||||
main_layout.addLayout(config_layout)
|
main_layout.addLayout(config_layout)
|
||||||
|
|
||||||
model_layout = QHBoxLayout()
|
model_layout = QHBoxLayout()
|
||||||
self.reasoning_model_label = QLabel('推理模型:')
|
self.cae_model_label = QLabel('模块分析模型:')
|
||||||
self.reasoning_model_input = QLineEdit()
|
self.cae_model_input = QLineEdit()
|
||||||
self.embedding_model_label = QLabel('嵌入模型:')
|
self.csa_model_label = QLabel('代码审计模型:')
|
||||||
self.embedding_model_input = QLineEdit()
|
self.csa_model_input = QLineEdit()
|
||||||
model_layout.addWidget(self.reasoning_model_label)
|
model_layout.addWidget(self.cae_model_label)
|
||||||
model_layout.addWidget(self.reasoning_model_input)
|
model_layout.addWidget(self.cae_model_input)
|
||||||
model_layout.addWidget(self.embedding_model_label)
|
model_layout.addWidget(self.csa_model_label)
|
||||||
model_layout.addWidget(self.embedding_model_input)
|
model_layout.addWidget(self.csa_model_input)
|
||||||
main_layout.addLayout(model_layout)
|
main_layout.addLayout(model_layout)
|
||||||
|
|
||||||
# 按钮部分
|
# 按钮部分
|
||||||
@@ -156,7 +162,9 @@ class MainWindow(QWidget):
|
|||||||
|
|
||||||
# 导出结果
|
# 导出结果
|
||||||
export_button_layout = QHBoxLayout()
|
export_button_layout = QHBoxLayout()
|
||||||
link_label = QLabel('联系作者:<a href="https://github.com/yv1ing">Github</a> <a href=mailto:me@yvling.cn>Email</a>')
|
link_label = QLabel(
|
||||||
|
'联系作者:<a href="https://github.com/yv1ing">Github</a> <a href=mailto:me@yvling.cn>Email</a>'
|
||||||
|
)
|
||||||
link_label.setOpenExternalLinks(True)
|
link_label.setOpenExternalLinks(True)
|
||||||
export_button_layout.addWidget(link_label)
|
export_button_layout.addWidget(link_label)
|
||||||
export_button_layout.addStretch(1)
|
export_button_layout.addStretch(1)
|
||||||
@@ -169,8 +177,8 @@ class MainWindow(QWidget):
|
|||||||
# 加载配置
|
# 加载配置
|
||||||
self.base_url_input.setText(GLOBAL_CONFIG['base_url'])
|
self.base_url_input.setText(GLOBAL_CONFIG['base_url'])
|
||||||
self.api_key_input.setText(GLOBAL_CONFIG['api_key'])
|
self.api_key_input.setText(GLOBAL_CONFIG['api_key'])
|
||||||
self.reasoning_model_input.setText(GLOBAL_CONFIG['reasoning_model'])
|
self.cae_model_input.setText(GLOBAL_CONFIG['cae_model'])
|
||||||
self.embedding_model_input.setText(GLOBAL_CONFIG['embedding_model'])
|
self.csa_model_input.setText(GLOBAL_CONFIG['csa_model'])
|
||||||
|
|
||||||
def closeEvent(self, event):
|
def closeEvent(self, event):
|
||||||
self.event.set()
|
self.event.set()
|
||||||
@@ -182,14 +190,13 @@ class MainWindow(QWidget):
|
|||||||
def update_config(self):
|
def update_config(self):
|
||||||
base_url = self.base_url_input.text()
|
base_url = self.base_url_input.text()
|
||||||
api_key = self.api_key_input.text()
|
api_key = self.api_key_input.text()
|
||||||
reasoning_model = self.reasoning_model_input.text()
|
cae_model = self.cae_model_input.text()
|
||||||
embedding_model = self.embedding_model_input.text()
|
csa_model = self.csa_model_input.text()
|
||||||
|
|
||||||
update_config('base_url', base_url)
|
|
||||||
update_config('api_key', api_key)
|
|
||||||
update_config('reasoning_model', reasoning_model)
|
|
||||||
update_config('embedding_model', embedding_model)
|
|
||||||
|
|
||||||
|
real_update_config('base_url', base_url)
|
||||||
|
real_update_config('api_key', api_key)
|
||||||
|
real_update_config('cae_model', cae_model)
|
||||||
|
real_update_config('csa_model', csa_model)
|
||||||
self.log.info('更新配置成功')
|
self.log.info('更新配置成功')
|
||||||
|
|
||||||
def select_directory(self):
|
def select_directory(self):
|
||||||
@@ -234,8 +241,8 @@ class MainWindow(QWidget):
|
|||||||
selected_lang = self.lang_combobox.currentText()
|
selected_lang = self.lang_combobox.currentText()
|
||||||
base_url = self.base_url_input.text()
|
base_url = self.base_url_input.text()
|
||||||
api_key = self.api_key_input.text()
|
api_key = self.api_key_input.text()
|
||||||
reasoning_model = self.reasoning_model_input.text()
|
cae_model = self.cae_model_input.text()
|
||||||
embedding_model = self.embedding_model_input.text()
|
csa_model = self.csa_model_input.text()
|
||||||
|
|
||||||
if not selected_dir or not base_url or not api_key:
|
if not selected_dir or not base_url or not api_key:
|
||||||
self.log.error('请确保项目目录、接口地址和模型密钥等都已填写')
|
self.log.error('请确保项目目录、接口地址和模型密钥等都已填写')
|
||||||
@@ -250,8 +257,8 @@ class MainWindow(QWidget):
|
|||||||
api_key,
|
api_key,
|
||||||
selected_dir,
|
selected_dir,
|
||||||
selected_lang,
|
selected_lang,
|
||||||
reasoning_model,
|
cae_model,
|
||||||
embedding_model,
|
csa_model,
|
||||||
self.process_output_callback,
|
self.process_output_callback,
|
||||||
self.result_output_callback,
|
self.result_output_callback,
|
||||||
self.event
|
self.event
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def get_now_date():
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
formatted = now.strftime("%Y-%m-%d-%H-%M-%S")
|
|
||||||
return formatted
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 85 KiB After Width: | Height: | Size: 50 KiB |
@@ -1,183 +0,0 @@
|
|||||||
import os
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
|
||||||
from langchain_core.messages import SystemMessage
|
|
||||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
||||||
from langchain_community.vectorstores import FAISS
|
|
||||||
from langchain_community.chat_message_histories import ChatMessageHistory
|
|
||||||
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
|
||||||
from langchain.retrievers import ContextualCompressionRetriever
|
|
||||||
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
|
|
||||||
from langchain_text_splitters import CharacterTextSplitter
|
|
||||||
|
|
||||||
from audit.rules import FROTIFY_RULES
|
|
||||||
from logger import Logger
|
|
||||||
from audit import callback
|
|
||||||
from audit.prompt import SYSTEM_PROMPT
|
|
||||||
from audit.language import LANGUAGE
|
|
||||||
|
|
||||||
xml_pattern = r'<root>.*?</root>'
|
|
||||||
|
|
||||||
|
|
||||||
class Audit:
|
|
||||||
def __init__(self, base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback):
|
|
||||||
self.raw_chain = None
|
|
||||||
self.source_files_list = []
|
|
||||||
self.reasoning_model = reasoning_model
|
|
||||||
self.embedding_model = embedding_model
|
|
||||||
self.fortify_rules = FROTIFY_RULES
|
|
||||||
self.process_output_callback = process_output_callback
|
|
||||||
self.result_output_callback = result_output_callback
|
|
||||||
self.chat_history = ChatMessageHistory()
|
|
||||||
self.session_id = uuid.uuid4().hex
|
|
||||||
self.response_callback = callback.CustomCallbackHandler()
|
|
||||||
self.embedding = OpenAIEmbeddings(
|
|
||||||
base_url=base_url,
|
|
||||||
api_key=api_key,
|
|
||||||
model=embedding_model
|
|
||||||
)
|
|
||||||
self.llm = ChatOpenAI(
|
|
||||||
base_url=base_url,
|
|
||||||
api_key=api_key,
|
|
||||||
model=reasoning_model,
|
|
||||||
streaming=True,
|
|
||||||
callbacks=[self.response_callback]
|
|
||||||
)
|
|
||||||
self.log = Logger('audit', callback=self.process_output_callback)
|
|
||||||
self.splitter = CharacterTextSplitter(
|
|
||||||
chunk_size=300,
|
|
||||||
chunk_overlap=0,
|
|
||||||
separator=". ",
|
|
||||||
)
|
|
||||||
self.messages_db = FAISS.from_texts(['nothing'], self.embedding)
|
|
||||||
self.retriever = self.messages_db.as_retriever()
|
|
||||||
self.redundant_filter = EmbeddingsRedundantFilter(embeddings=self.embedding)
|
|
||||||
self.relevant_filter = EmbeddingsFilter(
|
|
||||||
embeddings=self.embedding,
|
|
||||||
similarity_threshold=0.76,
|
|
||||||
)
|
|
||||||
self.pipeline_compressor = DocumentCompressorPipeline(
|
|
||||||
transformers=[self.splitter, self.redundant_filter, self.relevant_filter]
|
|
||||||
)
|
|
||||||
self.compression_retriever = ContextualCompressionRetriever(
|
|
||||||
base_compressor=self.pipeline_compressor,
|
|
||||||
base_retriever=self.retriever,
|
|
||||||
)
|
|
||||||
self.prompt = ChatPromptTemplate.from_messages([
|
|
||||||
SystemMessage(content=SYSTEM_PROMPT),
|
|
||||||
MessagesPlaceholder(variable_name='messages'),
|
|
||||||
('human', '{input}'),
|
|
||||||
])
|
|
||||||
|
|
||||||
def audit(self, event):
|
|
||||||
if len(self.source_files_list) <= 0:
|
|
||||||
self.log.error('没有找到源代码文件')
|
|
||||||
return
|
|
||||||
|
|
||||||
self.log.info('开始代码审计流程')
|
|
||||||
self.log.info(f'当前推理模型:{self.reasoning_model}')
|
|
||||||
self.log.info(f'当前嵌入模型:{self.embedding_model}')
|
|
||||||
|
|
||||||
input_content = ''
|
|
||||||
while True:
|
|
||||||
if event.is_set():
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = self.send_message(input_content)
|
|
||||||
except Exception as e:
|
|
||||||
self.log.error(e)
|
|
||||||
return
|
|
||||||
|
|
||||||
if event.is_set():
|
|
||||||
return
|
|
||||||
|
|
||||||
if xml_match := re.search(xml_pattern, result, re.DOTALL):
|
|
||||||
try:
|
|
||||||
xml_content = xml_match.group(0)
|
|
||||||
root = ET.fromstring(xml_content)
|
|
||||||
|
|
||||||
action = root.find('action').text
|
|
||||||
content = root.find('content').text
|
|
||||||
except Exception as e:
|
|
||||||
print(result)
|
|
||||||
print(e)
|
|
||||||
self.log.error(f'动作指令不合法,尝试纠正')
|
|
||||||
input_content = 'ILLEGAL OUTPUT'
|
|
||||||
continue
|
|
||||||
|
|
||||||
if action == 'QUERY STRUCTURE':
|
|
||||||
self.log.info('请求查询项目结构')
|
|
||||||
input_content = '\n'.join(x for x in self.source_files_list)
|
|
||||||
continue
|
|
||||||
elif action == 'QUERY SOURCE':
|
|
||||||
self.log.info(f'请求查询源代码:{content}')
|
|
||||||
input_content = open(content, 'r', encoding='utf-8').read()
|
|
||||||
continue
|
|
||||||
elif action == 'QUERY FORTIFY':
|
|
||||||
self.log.info(f'请求查询规则库:{content}')
|
|
||||||
input_content = '\n'.join(x for x in self.fortify_rules if x == content)
|
|
||||||
continue
|
|
||||||
elif action == 'OUTPUT RESULT':
|
|
||||||
self.log.warning('输出代码审计结果')
|
|
||||||
self.result_output_callback(content)
|
|
||||||
self.store_messages_in_faiss(content)
|
|
||||||
input_content = 'ok'
|
|
||||||
continue
|
|
||||||
elif action == 'FINISH TASK':
|
|
||||||
self.log.info('代码审计任务已完成')
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.log.error(f'动作指令未定义:{action}')
|
|
||||||
return
|
|
||||||
|
|
||||||
def send_message(self, input_content):
|
|
||||||
self.response_callback.temp_content = ''
|
|
||||||
compressed_context = self.compression_retriever.invoke(input_content)
|
|
||||||
|
|
||||||
if input_content == '':
|
|
||||||
input_content = 'nothing'
|
|
||||||
|
|
||||||
input_dict = {
|
|
||||||
'input': input_content,
|
|
||||||
'context': compressed_context,
|
|
||||||
}
|
|
||||||
config_dict = {
|
|
||||||
'configurable': {'session_id': self.session_id}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.raw_chain = self.prompt | self.llm
|
|
||||||
chain_with_message_history = RunnableWithMessageHistory(
|
|
||||||
self.raw_chain,
|
|
||||||
lambda session_id: self.chat_history,
|
|
||||||
input_messages_key='input',
|
|
||||||
history_messages_key='messages',
|
|
||||||
)
|
|
||||||
|
|
||||||
for _ in chain_with_message_history.stream(input_dict, config_dict):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return self.response_callback.temp_content
|
|
||||||
|
|
||||||
def store_messages_in_faiss(self, message):
|
|
||||||
text_embedding = self.embedding.embed_query(message)
|
|
||||||
doc_id = str(uuid.uuid4())
|
|
||||||
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
|
|
||||||
self.log.info(f"代码审计结果已缓存,文档编号:{doc_id}")
|
|
||||||
|
|
||||||
def load_source_files(self, path, lang):
|
|
||||||
if lang in LANGUAGE:
|
|
||||||
suffixes = LANGUAGE[lang]
|
|
||||||
else:
|
|
||||||
self.log.error('不支持的编程语言')
|
|
||||||
return
|
|
||||||
|
|
||||||
for root, _, files in os.walk(path):
|
|
||||||
self.source_files_list.extend(
|
|
||||||
os.path.join(root, file).replace('\\', '/') for file in files if any(file.endswith(suffix) for suffix in suffixes)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log.info(f'源代码文件加载完成,共:{len(self.source_files_list)} 个')
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
from langchain_core.callbacks import BaseCallbackHandler
|
|
||||||
|
|
||||||
|
|
||||||
class CustomCallbackHandler(BaseCallbackHandler):
|
|
||||||
def __init__(self):
|
|
||||||
self.temp_content = ''
|
|
||||||
|
|
||||||
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
|
||||||
self.temp_content += token
|
|
||||||
|
|
||||||
def on_llm_end(self, response, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
SYSTEM_PROMPT = """
|
|
||||||
You are a professional code audit security expert, responsible for helping users audit possible vulnerabilities and security issues in source code.
|
|
||||||
You will perform code audits according to the following process:
|
|
||||||
|
|
||||||
1. Query project structure
|
|
||||||
You input the action command in the following format, and the user will send you the absolute path of all source files in the project below:
|
|
||||||
<root>
|
|
||||||
<action>QUERY STRUCTURE</action>
|
|
||||||
<content></content>
|
|
||||||
</root>
|
|
||||||
|
|
||||||
2. Query the vulnerability detection rule base
|
|
||||||
You input the action instructions in the following format, and the user will send you the vulnerability detection rule library extracted from Fortify as a reference for your code audit:
|
|
||||||
<root>
|
|
||||||
<action>QUERY FORTIFY</action>
|
|
||||||
<content>The language you want to query, options are: c, cpp, go, php, jsp, java, python, javascript</content>
|
|
||||||
</root>
|
|
||||||
|
|
||||||
3. Query the source code
|
|
||||||
You input the action command in the following format, and the user will send you the source code you need below:
|
|
||||||
<root>
|
|
||||||
<action>QUERY SOURCE</action>
|
|
||||||
<content>the absolute path of the file you want to query</content>
|
|
||||||
</root>
|
|
||||||
|
|
||||||
4. Output code audit results
|
|
||||||
You input the code audit results in the following format, and the user will send you "ok", then you can proceed to the next step of the audit:
|
|
||||||
<root>
|
|
||||||
<action>OUTPUT RESULT</action>
|
|
||||||
<content>the audit results you want to output</content>
|
|
||||||
</root>
|
|
||||||
|
|
||||||
5. Finish audit task
|
|
||||||
When you are sure that all source code files have been audited, you can output the action instructions to end the task in the following format:
|
|
||||||
<root>
|
|
||||||
<action>FINISH TASK</action>
|
|
||||||
<content></content>
|
|
||||||
</root>
|
|
||||||
|
|
||||||
All your output can only be one of the five actions mentioned above. Any other form of output is strictly prohibited.
|
|
||||||
|
|
||||||
|
|
||||||
Some additional information, which are some specifications when you perform actions:
|
|
||||||
1. The format of the vulnerability detection rule base provided to you is as follows:
|
|
||||||
{
|
|
||||||
'language':
|
|
||||||
'vuln_kingdom':
|
|
||||||
'vuln_category':
|
|
||||||
}
|
|
||||||
|
|
||||||
2. When you output the code audit results, you must use Chinese output and follow the following format:
|
|
||||||
漏洞类型:
|
|
||||||
漏洞文件:
|
|
||||||
相关代码:
|
|
||||||
修复建议:
|
|
||||||
|
|
||||||
Some Mandatory regulations:
|
|
||||||
1. Output Format:
|
|
||||||
a. Strictly use the predefined XML tag structure
|
|
||||||
b. Any Markdown symbols are not allowed
|
|
||||||
c. No line breaks in the content field
|
|
||||||
2. Language Standards:
|
|
||||||
a. Technical terms are kept in their original English
|
|
||||||
b. Vulnerability descriptions must be in Chinese
|
|
||||||
3. Interaction restrictions:
|
|
||||||
a. Any content outside the output process is prohibited
|
|
||||||
b. Autonomously advance the audit process when receiving "nothing" or "ok"
|
|
||||||
c. Vulnerabilities must be output immediately
|
|
||||||
4. Error handling:
|
|
||||||
a. When receiving the "ILLEGAL OUTPUT" prompt, terminate the current output immediately and recheck the format specification before continuing
|
|
||||||
5. Priority logic:
|
|
||||||
a. Entry file > Configuration file > Tool file
|
|
||||||
b. High-risk vulnerabilities (such as injection and RCE) are handled first
|
|
||||||
c. If multiple vulnerabilities are found in the same file, they need to be output multiple times
|
|
||||||
d. For vulnerabilities that may span files, the audit can only begin after the relevant files have been queried as needed
|
|
||||||
"""
|
|
||||||
10845
audit/rules.py
10845
audit/rules.py
File diff suppressed because one or more lines are too long
1
main.py
1
main.py
@@ -15,6 +15,7 @@ def resource_path(relative_path):
|
|||||||
base_path = os.path.abspath(".")
|
base_path = os.path.abspath(".")
|
||||||
return os.path.join(base_path, relative_path)
|
return os.path.join(base_path, relative_path)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
try:
|
try:
|
||||||
app.exec()
|
app.exec()
|
||||||
|
|||||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
@@ -1,44 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
|
|
||||||
|
|
||||||
rules_list = []
|
|
||||||
|
|
||||||
def extract_rules(xml_file):
|
|
||||||
tree = ET.parse(xml_file)
|
|
||||||
root = tree.getroot()
|
|
||||||
|
|
||||||
rules = root.findall('.//{xmlns://www.fortifysoftware.com/schema/rules}StructuralRule')
|
|
||||||
|
|
||||||
for rule in rules:
|
|
||||||
rule_info = {}
|
|
||||||
|
|
||||||
vuln_kingdom = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}VulnKingdom')
|
|
||||||
vuln_category = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}VulnCategory')
|
|
||||||
vuln_subcategory = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}VulnSubcategory')
|
|
||||||
predicate = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}Predicate')
|
|
||||||
|
|
||||||
rule_info['language'] = rule.get('language')
|
|
||||||
|
|
||||||
if rule_info['language'] in ['c', 'cpp', 'go', 'php', 'jsp', 'java', 'python', 'javascript']:
|
|
||||||
rule_info['vuln_kingdom'] = vuln_kingdom.text.replace(' ', ' ') if vuln_kingdom is not None else None
|
|
||||||
rule_info['vuln_category'] = vuln_category.text.replace(' ', ' ') if vuln_category is not None else None
|
|
||||||
rule_info['vuln_subcategory'] = vuln_subcategory.text.replace(' ', ' ') if vuln_subcategory is not None else None
|
|
||||||
rule_info['predicate'] = predicate.text.replace(' ', ' ') if predicate is not None else None
|
|
||||||
|
|
||||||
rules_list.append(rule_info)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_fortify_rules(src_path):
|
|
||||||
for root, dirs, files in os.walk(src_path):
|
|
||||||
for file_name in files:
|
|
||||||
if file_name.endswith('.xml'):
|
|
||||||
file_path = os.path.join(root, file_name)
|
|
||||||
extract_rules(file_path)
|
|
||||||
|
|
||||||
open('../../fortify_rules.json', 'w', encoding='utf-8').write(json.dumps(rules_list))
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
load_fortify_rules(r'C:\Users\yvling\Desktop\data')
|
|
||||||
87
utils/__init__.py
Normal file
87
utils/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import os
|
||||||
|
import logger
|
||||||
|
from app.constants import SUPPORT_LANGUAGE
|
||||||
|
|
||||||
|
|
||||||
|
def clean_tree(node, suffixes):
|
||||||
|
sub_keys_to_remove = []
|
||||||
|
for key, value in node.items():
|
||||||
|
if key == 'files':
|
||||||
|
continue
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
clean_tree(value, suffixes)
|
||||||
|
has_valid_content = 'files' in value or any(isinstance(sub_val, dict) and sub_val for sub_val in value.values())
|
||||||
|
if not has_valid_content:
|
||||||
|
sub_keys_to_remove.append(key)
|
||||||
|
|
||||||
|
for key in sub_keys_to_remove:
|
||||||
|
del node[key]
|
||||||
|
|
||||||
|
if all(key == 'files' for key in node) and ('files' not in node or not node['files']):
|
||||||
|
node.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def print_tree(tree):
|
||||||
|
formatted_str = ''
|
||||||
|
formatted = format_tree(tree)
|
||||||
|
|
||||||
|
for line in formatted:
|
||||||
|
formatted_str += f"{line}\n"
|
||||||
|
|
||||||
|
return formatted_str
|
||||||
|
|
||||||
|
|
||||||
|
def format_tree(node, level=0):
|
||||||
|
result = []
|
||||||
|
indent = ' ' * level
|
||||||
|
|
||||||
|
for key, value in node.items():
|
||||||
|
if key == 'files':
|
||||||
|
for file in value:
|
||||||
|
result.append(f"{indent}- {file}")
|
||||||
|
else:
|
||||||
|
result.append(f"{indent}- {key}/")
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
result.extend(format_tree(value, level + 1))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def build_directory_tree(path, lang, callback):
|
||||||
|
log = logger.Logger('app', callback=callback)
|
||||||
|
|
||||||
|
if lang in SUPPORT_LANGUAGE:
|
||||||
|
suffixes = SUPPORT_LANGUAGE[lang]
|
||||||
|
else:
|
||||||
|
log.error(f'不支持的语言:{lang}')
|
||||||
|
return
|
||||||
|
|
||||||
|
absolute_path = os.path.abspath(path).replace('\\', '/')
|
||||||
|
src_files = []
|
||||||
|
tree = {absolute_path: {}}
|
||||||
|
|
||||||
|
for root, _, files in os.walk(absolute_path):
|
||||||
|
relative_path = os.path.relpath(root, absolute_path)
|
||||||
|
current_node = tree[absolute_path]
|
||||||
|
|
||||||
|
if relative_path != '.':
|
||||||
|
parts = relative_path.split(os.sep)
|
||||||
|
for part in parts:
|
||||||
|
if part not in current_node:
|
||||||
|
current_node[part] = {}
|
||||||
|
current_node = current_node[part]
|
||||||
|
|
||||||
|
for suffix in suffixes:
|
||||||
|
lang_files = [file for file in files if file.endswith(suffix)]
|
||||||
|
if lang_files:
|
||||||
|
if 'files' not in current_node:
|
||||||
|
current_node['files'] = []
|
||||||
|
current_node['files'].extend(lang_files)
|
||||||
|
src_files.extend(lang_files)
|
||||||
|
|
||||||
|
if len(src_files) > 0:
|
||||||
|
clean_tree(tree[absolute_path], suffixes)
|
||||||
|
return print_tree(tree)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user