Abandon the Fortify rule base, update prompt words, and optimize the data structure of the items to be audited

This commit is contained in:
2025-02-10 12:55:01 +08:00
parent c8bb6270a1
commit 5467f72bd7
10 changed files with 115 additions and 10954 deletions

View File

@@ -1,3 +1,4 @@
import json
import os
import re
import uuid
@@ -12,8 +13,6 @@ from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_text_splitters import CharacterTextSplitter
from audit.rules import FROTIFY_RULES
from logger import Logger
from audit import callback
from audit.prompt import SYSTEM_PROMPT
@@ -23,12 +22,12 @@ xml_pattern = r'<root>.*?</root>'
class Audit:
def __init__(self, base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback):
def __init__(self, base_url, api_key, reasoning_model, embedding_model, process_output_callback,
result_output_callback):
self.raw_chain = None
self.source_files_list = []
self.directory_tree = None
self.reasoning_model = reasoning_model
self.embedding_model = embedding_model
self.fortify_rules = FROTIFY_RULES
self.process_output_callback = process_output_callback
self.result_output_callback = result_output_callback
self.chat_history = ChatMessageHistory()
@@ -73,10 +72,6 @@ class Audit:
])
def audit(self, event):
if len(self.source_files_list) <= 0:
self.log.error('没有找到源代码文件')
return
self.log.info('开始代码审计流程')
self.log.info(f'当前推理模型:{self.reasoning_model}')
self.log.info(f'当前嵌入模型:{self.embedding_model}')
@@ -98,10 +93,20 @@ class Audit:
if xml_match := re.search(xml_pattern, result, re.DOTALL):
try:
xml_content = xml_match.group(0)
xml_content = re.sub(
r'(<content>)(.*?)(</content>)',
r'\1<![CDATA[\2]]>\3',
xml_content,
flags=re.DOTALL
)
root = ET.fromstring(xml_content)
action = root.find('action').text
content = root.find('content').text
if content and content.startswith('<![CDATA[') and content.endswith(']]>'):
content = content[9:-3]
except Exception as e:
print(result)
print(e)
@@ -109,30 +114,35 @@ class Audit:
input_content = 'ILLEGAL OUTPUT'
continue
if action == 'QUERY STRUCTURE':
self.log.info('请求查询项目结构')
input_content = '\n'.join(x for x in self.source_files_list)
try:
if action == 'QUERY STRUCTURE':
self.log.info('请求查询项目结构')
input_content = self.print_tree(self.directory_tree)
self.store_messages_in_faiss(input_content)
continue
elif action == 'QUERY SOURCE':
self.log.info(f'请求查询源代码:{content}')
input_content = open(content, 'r', encoding='utf-8').read()
self.store_messages_in_faiss(input_content)
continue
elif action == 'OUTPUT RESULT':
self.log.warning('输出代码审计结果')
dict_content = eval(content)
json_content = json.loads(json.dumps(dict_content))
output_content = f'漏洞类型:{json_content["漏洞类型"]}\n漏洞文件:{json_content["漏洞文件"]}\n相关代码:\n{json_content["相关代码"]}\n修复建议:\n{json_content["修复建议"]}\n'
self.result_output_callback(output_content)
self.store_messages_in_faiss(output_content)
input_content = 'ok'
continue
elif action == 'FINISH TASK':
self.log.info('代码审计任务已完成')
return
else:
self.log.error(f'动作指令未定义:{action}')
return
except Exception as e:
self.log.error(e)
continue
elif action == 'QUERY SOURCE':
self.log.info(f'请求查询源代码:{content}')
input_content = open(content, 'r', encoding='utf-8').read()
continue
elif action == 'QUERY FORTIFY':
self.log.info(f'请求查询规则库:{content}')
input_content = '\n'.join(x for x in self.fortify_rules if x == content)
continue
elif action == 'OUTPUT RESULT':
self.log.warning('输出代码审计结果')
self.result_output_callback(content)
self.store_messages_in_faiss(content)
input_content = 'ok'
continue
elif action == 'FINISH TASK':
self.log.info('代码审计任务已完成')
return
else:
self.log.error(f'动作指令未定义:{action}')
return
def send_message(self, input_content):
self.response_callback.temp_content = ''
@@ -166,18 +176,58 @@ class Audit:
text_embedding = self.embedding.embed_query(message)
doc_id = str(uuid.uuid4())
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
self.log.info(f"代码审计结果已缓存,文档编号:{doc_id}")
def load_source_files(self, path, lang):
def build_directory_tree(self, path, lang):
if lang in LANGUAGE:
suffixes = LANGUAGE[lang]
else:
self.log.error('不支持的编程语言')
self.log.error(f'不支持的语言{lang}')
return
for root, _, files in os.walk(path):
self.source_files_list.extend(
os.path.join(root, file).replace('\\', '/') for file in files if any(file.endswith(suffix) for suffix in suffixes)
)
absolute_path = os.path.abspath(path).replace('\\', '/')
tree = {absolute_path: {}}
self.log.info(f'源代码文件加载完成,共:{len(self.source_files_list)}')
for root, _, files in os.walk(absolute_path):
relative_path = os.path.relpath(root, absolute_path)
current_node = tree[absolute_path]
if relative_path != '.':
parts = relative_path.split(os.sep)
for part in parts:
if part not in current_node:
current_node[part] = {}
current_node = current_node[part]
for suffix in suffixes:
lang_files = [file for file in files if file.endswith(suffix)]
if lang_files:
if 'files' not in current_node:
current_node['files'] = []
current_node['files'].extend(lang_files)
self.print_tree(tree)
self.directory_tree = tree
def format_tree(self, node, level=0):
result = []
indent = ' ' * level
for key, value in node.items():
if key == 'files':
for file in value:
result.append(f"{indent}- {file}")
else:
result.append(f"{indent}- {key}/")
if isinstance(value, dict):
result.extend(self.format_tree(value, level + 1))
return result
def print_tree(self, tree):
formatted_str = ''
formatted = self.format_tree(tree)
for line in formatted:
formatted_str += f"{line}\n"
# print(line)
return formatted_str