Complete the basic process of code audit

This commit is contained in:
2025-02-09 00:59:35 +08:00
parent eaff2123e7
commit 32e61b8d24
7 changed files with 244 additions and 1 deletions

2
.gitignore vendored
View File

@@ -165,7 +165,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
# PyPI configuration file
.pypirc

131
audit/__init__.py Normal file
View File

@@ -0,0 +1,131 @@
import os
import re
import time
import uuid
import xml.etree.ElementTree as ET
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_core.messages import SystemMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
from logger import Logger
from audit import callback
from audit.prompt import SYSTEM_PROMPT
# Name of the chat model handed to ChatOpenAI in Audit.__init__.
reasoning_model = 'gemini-2.0-flash-thinking-exp'
# Name of the embedding model handed to OpenAIEmbeddings in Audit.__init__.
embedding_model = 'text-embedding-3-large'
# Matches the first <root>...</root> envelope in a model reply; Audit.audit
# searches with re.DOTALL so the envelope may span several lines.
xml_pattern = r'<root>.*?</root>'
class Audit:
    """Drive an LLM-based code audit over a set of source files.

    The model is instructed (via SYSTEM_PROMPT) to answer with a small XML
    envelope: ``<root><action>...</action><content>...</content></root>``.
    ``audit`` reads each envelope, serves the request and loops until the
    model sends ``FINISH TASK`` or an unknown action.
    """

    # Source-file suffixes recognised for each supported language. Every
    # suffix carries its leading dot so that e.g. "xphp4" is not mistaken
    # for a PHP file.
    _LANGUAGE_SUFFIXES = {
        'php': ('.php', '.php3', '.php4', '.php5'),
        'python': ('.py',),
        'java': ('.java',),
        'c': ('.c',),
        'c++': ('.cpp', '.cc'),
        'javascript': ('.js',),
        'go': ('.go',),
    }

    def __init__(self):
        """Create the model clients, prompt template and empty chat history."""
        self.raw_chain = None
        # Built lazily on the first send_message call and then reused.
        self._history_chain = None
        self.source_files_list = []
        self.chat_history = ChatMessageHistory()
        self.session_id = uuid.uuid4().hex
        # Collects streamed tokens; send_message reads the text back from it.
        self.response_callback = callback.CustomCallbackHandler()
        self.embedding = OpenAIEmbeddings(model=embedding_model)
        self.llm = ChatOpenAI(model=reasoning_model, streaming=True, callbacks=[self.response_callback])
        self.log = Logger('audit')
        self.prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=SYSTEM_PROMPT),
            MessagesPlaceholder(variable_name='messages'),
            ('human', '{input}'),
        ])

    def audit(self, callback_function):
        """Run the audit conversation until the model finishes.

        Args:
            callback_function: Called with the content of every
                ``OUTPUT RESULT`` action so callers can collect findings.
        """
        self.log.info('Start auditing')
        input_content = ''
        while True:
            # Fixed pause between requests (presumably to stay under API
            # rate limits - confirm before changing).
            time.sleep(3)
            result = self.send_message(input_content)
            xml_match = re.search(xml_pattern, result, re.DOTALL)
            if not xml_match:
                # Same retry as before (the previous input is sent again),
                # but no longer silent.
                self.log.warning('No <root> element found in the model response, retrying')
                continue

            action, content = self._parse_response(xml_match.group(0))
            if action == 'QUERY STRUCTURE':
                self.log.info('Request to query project structure')
                input_content = '\n'.join(self.source_files_list)
            elif action == 'QUERY SOURCE':
                self.log.info(f'Request source code: {content}')
                input_content = self._read_source((content or '').strip())
            elif action == 'OUTPUT RESULT':
                self.log.warning(f'Audit result: \n{content}\n')
                callback_function(content)  # Callback function, used to obtain results externally
                input_content = ''
            elif action == 'FINISH TASK':
                # The system prompt defines this action as the regular end
                # of the audit, so it is not an error.
                self.log.info('Audit finished')
                break
            else:
                self.log.critical(f'Unknown action! {action}')
                break

    @staticmethod
    def _parse_response(xml_content):
        """Extract ``(action, content)`` from one ``<root>`` envelope.

        Model output is not guaranteed to be well-formed XML (audit results
        often contain raw ``<`` or ``&``), so a regex fallback is used when
        the XML parser rejects the text. ``action`` is returned stripped
        ('' when missing); ``content`` is returned as found (None when the
        element is empty or missing).
        """
        try:
            root = ET.fromstring(xml_content)
        except ET.ParseError:
            action_match = re.search(r'<action>(.*?)</action>', xml_content, re.DOTALL)
            # Greedy on purpose: the content itself may contain markup.
            content_match = re.search(r'<content>(.*)</content>', xml_content, re.DOTALL)
            action = action_match.group(1) if action_match else None
            content = content_match.group(1) if content_match else None
        else:
            action = root.findtext('action')
            content_element = root.find('content')
            content = content_element.text if content_element is not None else None
        return (action or '').strip(), content

    def _read_source(self, path):
        """Return the text of ``path``, or an error message for the model.

        NOTE(review): the path comes from the model and is not restricted to
        ``source_files_list``; any readable file could be sent to the remote
        API. Consider limiting it to the loaded files.
        """
        try:
            with open(path, 'r', encoding='utf-8') as source_file:
                return source_file.read()
        except (OSError, UnicodeDecodeError) as err:
            # Report the failure back to the model instead of crashing the
            # whole audit on one bad path.
            self.log.error(f'Failed to read {path}: {err}')
            return f'Failed to read file {path}: {err}'

    def _get_history_chain(self):
        """Build the history-aware chain once and reuse it afterwards."""
        if self._history_chain is None:
            self.raw_chain = self.prompt | self.llm
            self._history_chain = RunnableWithMessageHistory(
                self.raw_chain,
                lambda session_id: self.chat_history,
                input_messages_key='input',
                history_messages_key='messages',
            )
        return self._history_chain

    def send_message(self, input_content):
        """Send one user message and return the model's full streamed reply.

        An empty ``input_content`` is replaced by the literal ``'nothing'``,
        which the system prompt tells the model to treat as "decide the next
        step yourself".
        """
        self.response_callback.temp_content = ''
        if input_content == '':
            input_content = 'nothing'
        input_dict = {
            'input': input_content,
            'context': '',
        }
        config_dict = {
            'configurable': {'session_id': self.session_id}
        }
        chain_with_message_history = self._get_history_chain()
        # The stream is drained only for its side effect: the callback
        # handler accumulates every token into temp_content.
        for _ in chain_with_message_history.stream(input_dict, config_dict):
            pass
        return self.response_callback.temp_content

    def load_source_files(self, path, language):
        """Collect the source files of ``language`` found under ``path``.

        Matching files are appended to ``source_files_list``. An unsupported
        language is logged as critical and nothing is loaded.
        """
        self.log.info('Loading source files')
        suffixes = self._LANGUAGE_SUFFIXES.get(language)
        if suffixes is None:
            self.log.critical('Language not supported!')
            return
        for root, _dirs, files in os.walk(path):
            self.source_files_list.extend(
                os.path.join(root, file_name)
                for file_name in files
                if file_name.endswith(suffixes)
            )
        self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}')

13
audit/callback.py Normal file
View File

@@ -0,0 +1,13 @@
from langchain_core.callbacks import BaseCallbackHandler
class CustomCallbackHandler(BaseCallbackHandler):
    """Callback handler that gathers streamed tokens into ``temp_content``."""

    def __init__(self):
        # Text accumulated so far from on_llm_new_token calls.
        self.temp_content = ''

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append the newly received ``token`` to ``temp_content``."""
        self.temp_content = self.temp_content + token

    def on_llm_end(self, response, **kwargs):
        """Do nothing when the run ends; the gathered text is left in place."""
        return None

36
audit/prompt.py Normal file
View File

@@ -0,0 +1,36 @@
# System prompt for the audit model. It defines the XML action protocol the
# model must answer with (QUERY STRUCTURE, QUERY SOURCE, OUTPUT RESULT,
# FINISH TASK) and the rules for how the conversation proceeds.
SYSTEM_PROMPT = """
You are an intelligent code auditor. I will provide you with a source code. Please strictly follow the following requirements to conduct code audit.
All your output must strictly follow the following specifications. It is forbidden to output in any other form (including plain text, Markdown, etc.).
and it is forbidden to bring "`" when outputting:
1. Query project structure:
<root>
<action>QUERY STRUCTURE</action>
<content></content>
</root>
2. Query code files
<root>
<action>QUERY SOURCE</action>
<content>the absolute path of the file you want to query</content>
</root>
3. Output audit results
<root>
<action>OUTPUT RESULT</action>
<content>the audit results you want to output</content>
</root>
4. End the audit task
<root>
<action>FINISH TASK</action>
<content></content>
</root>
Important things:
1. When the user sends you "nothing", you need to decide the next step based on the current audit progress;
2. When you make an action to query the project structure, the user will send you the following format (C:\\Users\\yvling\\Desktop\\PHP-Vuln\\src\\index.php), which is a text containing the absolute paths of several source code files. You need to construct the project structure that you can understand based on these contents;
3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow The above format outputs the absolute path of the file to be queried;
4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task;
5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest;
"""

40
logger/__init__.py Normal file
View File

@@ -0,0 +1,40 @@
from datetime import datetime
# ANSI escape sequence used to colour the output of each log level.
LOG_COLORS = {
    'DEBUG': '\033[94m', # blue
    'INFO': '\033[92m', # green
    'WARNING': '\033[93m', # yellow
    'ERROR': '\033[91m', # red
    'CRITICAL': '\033[95m' # purple
}
# ANSI escape sequence appended after each message to reset the colour;
# also the fallback colour for levels missing from LOG_COLORS.
RESET_COLOR = '\033[0m'
def log_with_color(level, message):
    """Print ``message`` with a timestamp and a ``[level]`` tag, coloured by level.

    Levels that are not keys of LOG_COLORS use RESET_COLOR as their colour
    prefix. The line always ends with RESET_COLOR.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    start_color = LOG_COLORS.get(level, RESET_COLOR)
    print(f"{start_color}{timestamp} [{level}] {message}{RESET_COLOR}")
class Logger:
    """Minimal logger: each level method prints one coloured line via log_with_color."""

    def __init__(self, name):
        # ``name`` is accepted but neither stored nor used.
        return None

    def _emit(self, level, message):
        """Forward ``message`` to log_with_color under the given level tag."""
        log_with_color(level, message)

    def debug(self, message):
        """Print ``message`` at DEBUG level."""
        self._emit("DEBUG", message)

    def info(self, message):
        """Print ``message`` at INFO level."""
        self._emit("INFO", message)

    def warning(self, message):
        """Print ``message`` at WARNING level."""
        self._emit("WARNING", message)

    def error(self, message):
        """Print ``message`` at ERROR level."""
        self._emit("ERROR", message)

    def critical(self, message):
        """Print ``message`` at CRITICAL level."""
        self._emit("CRITICAL", message)

23
main.py Normal file
View File

@@ -0,0 +1,23 @@
import os
import warnings
from audit import Audit
# Silence FutureWarning noise from dependencies.
warnings.simplefilter('ignore', FutureWarning)
# Default API endpoint; an OPENAI_API_BASE already present in the environment
# takes precedence.
os.environ.setdefault('OPENAI_API_BASE', 'https://yunwu.ai/v1')
# SECURITY: the API key must be supplied through the OPENAI_API_KEY
# environment variable and never committed to source control. The key that
# used to be hard-coded here should be treated as leaked and rotated.
if not os.environ.get('OPENAI_API_KEY'):
    warnings.warn(
        'OPENAI_API_KEY is not set; export it before running the audit.',
        RuntimeWarning,
    )
def result_callback(result):
    """Receive one audit result; currently a placeholder that does nothing."""
    return None
if __name__ == '__main__':
    # Imported here because it is only needed when run as a script.
    import argparse

    # The source root and language used to be hard-coded; they are now
    # optional command-line arguments whose defaults keep the old behaviour
    # when the script is run without arguments.
    parser = argparse.ArgumentParser(
        description='Run an LLM-driven code audit over a source tree.',
    )
    parser.add_argument(
        'src_root',
        nargs='?',
        default=r'C:\Users\yvling\Desktop\PHP-Vuln',
        help='root directory of the project to audit',
    )
    parser.add_argument(
        '--language',
        default='php',
        help='language of the project (e.g. php, python, java, c, c++, javascript, go)',
    )
    args = parser.parse_args()

    audit = Audit()
    audit.load_source_files(args.src_root, args.language)
    audit.audit(result_callback)

BIN
requirements.txt Normal file

Binary file not shown.