Abandon the Fortify rule base, update prompt words, and optimize the data structure of the items to be audited

This commit is contained in:
2025-02-10 12:55:01 +08:00
parent c8bb6270a1
commit 5467f72bd7
10 changed files with 115 additions and 10954 deletions

View File

@@ -2,6 +2,12 @@
An automated code auditing tool powered by langchain. An automated code auditing tool powered by langchain.
Tool interface:
![](assets/img-01.png) ![](assets/img-01.png)
Result Export:
![](assets/img-02.png)
I welcome your suggestions for interesting tools :smile: I welcome your suggestions for interesting tools :smile:

View File

@@ -44,5 +44,5 @@ def update_config(key, value):
def audit_code(base_url, api_key, src_root, language, reasoning_model, embedding_model, process_output_callback, def audit_code(base_url, api_key, src_root, language, reasoning_model, embedding_model, process_output_callback,
result_output_callback, event): result_output_callback, event):
audit = Audit(base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback) audit = Audit(base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback)
audit.load_source_files(src_root, language) audit.build_directory_tree(src_root, language)
audit.audit(event) audit.audit(event)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 64 KiB

After

Width:  |  Height:  |  Size: 121 KiB

BIN
assets/img-02.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

View File

@@ -1,3 +1,4 @@
import json
import os import os
import re import re
import uuid import uuid
@@ -12,8 +13,6 @@ from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_text_splitters import CharacterTextSplitter from langchain_text_splitters import CharacterTextSplitter
from audit.rules import FROTIFY_RULES
from logger import Logger from logger import Logger
from audit import callback from audit import callback
from audit.prompt import SYSTEM_PROMPT from audit.prompt import SYSTEM_PROMPT
@@ -23,12 +22,12 @@ xml_pattern = r'<root>.*?</root>'
class Audit: class Audit:
def __init__(self, base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback): def __init__(self, base_url, api_key, reasoning_model, embedding_model, process_output_callback,
result_output_callback):
self.raw_chain = None self.raw_chain = None
self.source_files_list = [] self.directory_tree = None
self.reasoning_model = reasoning_model self.reasoning_model = reasoning_model
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.fortify_rules = FROTIFY_RULES
self.process_output_callback = process_output_callback self.process_output_callback = process_output_callback
self.result_output_callback = result_output_callback self.result_output_callback = result_output_callback
self.chat_history = ChatMessageHistory() self.chat_history = ChatMessageHistory()
@@ -73,10 +72,6 @@ class Audit:
]) ])
def audit(self, event): def audit(self, event):
if len(self.source_files_list) <= 0:
self.log.error('没有找到源代码文件')
return
self.log.info('开始代码审计流程') self.log.info('开始代码审计流程')
self.log.info(f'当前推理模型:{self.reasoning_model}') self.log.info(f'当前推理模型:{self.reasoning_model}')
self.log.info(f'当前嵌入模型:{self.embedding_model}') self.log.info(f'当前嵌入模型:{self.embedding_model}')
@@ -98,10 +93,20 @@ class Audit:
if xml_match := re.search(xml_pattern, result, re.DOTALL): if xml_match := re.search(xml_pattern, result, re.DOTALL):
try: try:
xml_content = xml_match.group(0) xml_content = xml_match.group(0)
xml_content = re.sub(
r'(<content>)(.*?)(</content>)',
r'\1<![CDATA[\2]]>\3',
xml_content,
flags=re.DOTALL
)
root = ET.fromstring(xml_content) root = ET.fromstring(xml_content)
action = root.find('action').text action = root.find('action').text
content = root.find('content').text content = root.find('content').text
if content and content.startswith('<![CDATA[') and content.endswith(']]>'):
content = content[9:-3]
except Exception as e: except Exception as e:
print(result) print(result)
print(e) print(e)
@@ -109,30 +114,35 @@ class Audit:
input_content = 'ILLEGAL OUTPUT' input_content = 'ILLEGAL OUTPUT'
continue continue
if action == 'QUERY STRUCTURE': try:
self.log.info('请求查询项目结构') if action == 'QUERY STRUCTURE':
input_content = '\n'.join(x for x in self.source_files_list) self.log.info('请求查询项目结构')
input_content = self.print_tree(self.directory_tree)
self.store_messages_in_faiss(input_content)
continue
elif action == 'QUERY SOURCE':
self.log.info(f'请求查询源代码:{content}')
input_content = open(content, 'r', encoding='utf-8').read()
self.store_messages_in_faiss(input_content)
continue
elif action == 'OUTPUT RESULT':
self.log.warning('输出代码审计结果')
dict_content = eval(content)
json_content = json.loads(json.dumps(dict_content))
output_content = f'漏洞类型:{json_content["漏洞类型"]}\n漏洞文件:{json_content["漏洞文件"]}\n相关代码:\n{json_content["相关代码"]}\n修复建议:\n{json_content["修复建议"]}\n'
self.result_output_callback(output_content)
self.store_messages_in_faiss(output_content)
input_content = 'ok'
continue
elif action == 'FINISH TASK':
self.log.info('代码审计任务已完成')
return
else:
self.log.error(f'动作指令未定义:{action}')
return
except Exception as e:
self.log.error(e)
continue continue
elif action == 'QUERY SOURCE':
self.log.info(f'请求查询源代码:{content}')
input_content = open(content, 'r', encoding='utf-8').read()
continue
elif action == 'QUERY FORTIFY':
self.log.info(f'请求查询规则库:{content}')
input_content = '\n'.join(x for x in self.fortify_rules if x == content)
continue
elif action == 'OUTPUT RESULT':
self.log.warning('输出代码审计结果')
self.result_output_callback(content)
self.store_messages_in_faiss(content)
input_content = 'ok'
continue
elif action == 'FINISH TASK':
self.log.info('代码审计任务已完成')
return
else:
self.log.error(f'动作指令未定义:{action}')
return
def send_message(self, input_content): def send_message(self, input_content):
self.response_callback.temp_content = '' self.response_callback.temp_content = ''
@@ -166,18 +176,58 @@ class Audit:
text_embedding = self.embedding.embed_query(message) text_embedding = self.embedding.embed_query(message)
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}]) self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
self.log.info(f"代码审计结果已缓存,文档编号:{doc_id}")
def load_source_files(self, path, lang): def build_directory_tree(self, path, lang):
if lang in LANGUAGE: if lang in LANGUAGE:
suffixes = LANGUAGE[lang] suffixes = LANGUAGE[lang]
else: else:
self.log.error('不支持的编程语言') self.log.error(f'不支持的语言{lang}')
return return
for root, _, files in os.walk(path): absolute_path = os.path.abspath(path).replace('\\', '/')
self.source_files_list.extend( tree = {absolute_path: {}}
os.path.join(root, file).replace('\\', '/') for file in files if any(file.endswith(suffix) for suffix in suffixes)
)
self.log.info(f'源代码文件加载完成,共:{len(self.source_files_list)}') for root, _, files in os.walk(absolute_path):
relative_path = os.path.relpath(root, absolute_path)
current_node = tree[absolute_path]
if relative_path != '.':
parts = relative_path.split(os.sep)
for part in parts:
if part not in current_node:
current_node[part] = {}
current_node = current_node[part]
for suffix in suffixes:
lang_files = [file for file in files if file.endswith(suffix)]
if lang_files:
if 'files' not in current_node:
current_node['files'] = []
current_node['files'].extend(lang_files)
self.print_tree(tree)
self.directory_tree = tree
def format_tree(self, node, level=0):
result = []
indent = ' ' * level
for key, value in node.items():
if key == 'files':
for file in value:
result.append(f"{indent}- {file}")
else:
result.append(f"{indent}- {key}/")
if isinstance(value, dict):
result.extend(self.format_tree(value, level + 1))
return result
def print_tree(self, tree):
formatted_str = ''
formatted = self.format_tree(tree)
for line in formatted:
formatted_str += f"{line}\n"
# print(line)
return formatted_str

View File

@@ -1,58 +1,51 @@
SYSTEM_PROMPT = """ SYSTEM_PROMPT = """
You are a professional code audit security expert, responsible for helping users audit possible vulnerabilities and security issues in source code. You are a professional code audit security expert, responsible for helping users audit possible vulnerabilities and security issues in source code.
You will perform code audits according to the following process: You will perform code audits according to the following process:
1. Query project structure 1. Query project structure
You input the action command in the following format, and the user will send you the absolute path of all source files in the project below: You input the action command in the following format, and the user will send you the project structure below:
<root> <root>
<action>QUERY STRUCTURE</action> <action>QUERY STRUCTURE</action>
<content></content> <content></content>
</root> </root>
2. Query the vulnerability detection rule base 2. Query the source code
You input the action instructions in the following format, and the user will send you the vulnerability detection rule library extracted from Fortify as a reference for your code audit:
<root>
<action>QUERY FORTIFY</action>
<content>The language you want to query, options are: c, cpp, go, php, jsp, java, python, javascript</content>
</root>
3. Query the source code
You input the action command in the following format, and the user will send you the source code you need below: You input the action command in the following format, and the user will send you the source code you need below:
<root> <root>
<action>QUERY SOURCE</action> <action>QUERY SOURCE</action>
<content>the absolute path of the file you want to query</content> <content>the absolute path of the file you want to query</content>
</root> </root>
4. Output code audit results 3. Output code audit results
You input the code audit results in the following format, and the user will send you "ok", then you can proceed to the next step of the audit: You input the code audit results in the following format, and the user will send you "ok", then you can proceed to the next step of the audit:
<root> <root>
<action>OUTPUT RESULT</action> <action>OUTPUT RESULT</action>
<content>the audit results you want to output</content> <content>the audit results you want to output</content>
</root> </root>
5. Finish audit task 4. Finish audit task
When you are sure that all source code files have been audited, you can output the action instructions to end the task in the following format: When you are sure that all source code files have been audited, you can output the action instructions to end the task in the following format:
<root> <root>
<action>FINISH TASK</action> <action>FINISH TASK</action>
<content></content> <content></content>
</root> </root>
All your output can only be one of the five actions mentioned above. Any other form of output is strictly prohibited. All your output can only be one of the 4 actions mentioned above. Any other form of output is strictly prohibited.
Some additional information, which are some specifications when you perform actions: Some additional information, which are some specifications when you perform actions:
1. The format of the vulnerability detection rule base provided to you is as follows: 1. The project structure format sent to you is as follows. You need to construct the complete absolute path of the file you want to query based on these hierarchical relationships:
{ - C:/Users/yvling/Desktop/test/
'language': - dir_1/
'vuln_kingdom': - 1.php
'vuln_category': - dir_2/
} - 2.php
- dir_3/
- 3.php
2. When you output the code audit results, you must use Chinese output and follow the following format: 2. When you output the code audit results, you must use Chinese output and follow the following format(Python dict):
漏洞类型: {'漏洞类型': 'SQL Injection', '漏洞文件': 'main.java', '相关代码': '```java\nString id=request.getParameter("id");\nres = st.executeQuery("SELECT* FROM\"IWEBSEC\".\"user\" WHERE \"id\"="+id);\n```', '修复建议': 'your suggestions...'}
漏洞文件:
相关代码: Most important: Only output audit results with vulnerabilities, and prohibit output without vulnerabilities!
修复建议:
Some Mandatory regulations: Some Mandatory regulations:
1. Output Format: 1. Output Format:
@@ -73,4 +66,5 @@ Some Mandatory regulations:
b. High-risk vulnerabilities (such as injection and RCE) are handled first b. High-risk vulnerabilities (such as injection and RCE) are handled first
c. If multiple vulnerabilities are found in the same file, they need to be output multiple times c. If multiple vulnerabilities are found in the same file, they need to be output multiple times
d. For vulnerabilities that may span files, the audit can only begin after the relevant files have been queried as needed d. For vulnerabilities that may span files, the audit can only begin after the relevant files have been queried as needed
e. Only output audit results with vulnerabilities, and prohibit output without vulnerabilities
""" """

File diff suppressed because one or more lines are too long

View File

View File

@@ -1,44 +0,0 @@
import json
import os
import xml.etree.ElementTree as ET
rules_list = []
def extract_rules(xml_file):
tree = ET.parse(xml_file)
root = tree.getroot()
rules = root.findall('.//{xmlns://www.fortifysoftware.com/schema/rules}StructuralRule')
for rule in rules:
rule_info = {}
vuln_kingdom = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}VulnKingdom')
vuln_category = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}VulnCategory')
vuln_subcategory = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}VulnSubcategory')
predicate = rule.find('{xmlns://www.fortifysoftware.com/schema/rules}Predicate')
rule_info['language'] = rule.get('language')
if rule_info['language'] in ['c', 'cpp', 'go', 'php', 'jsp', 'java', 'python', 'javascript']:
rule_info['vuln_kingdom'] = vuln_kingdom.text.replace(' ', ' ') if vuln_kingdom is not None else None
rule_info['vuln_category'] = vuln_category.text.replace(' ', ' ') if vuln_category is not None else None
rule_info['vuln_subcategory'] = vuln_subcategory.text.replace(' ', ' ') if vuln_subcategory is not None else None
rule_info['predicate'] = predicate.text.replace(' ', ' ') if predicate is not None else None
rules_list.append(rule_info)
def load_fortify_rules(src_path):
for root, dirs, files in os.walk(src_path):
for file_name in files:
if file_name.endswith('.xml'):
file_path = os.path.join(root, file_name)
extract_rules(file_path)
open('../../fortify_rules.json', 'w', encoding='utf-8').write(json.dumps(rules_list))
if __name__ == '__main__':
load_fortify_rules(r'C:\Users\yvling\Desktop\data')