diff --git a/.gitignore b/.gitignore index 15201ac..d86e4ff 100644 --- a/.gitignore +++ b/.gitignore @@ -165,7 +165,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ # PyPI configuration file .pypirc diff --git a/audit/__init__.py b/audit/__init__.py new file mode 100644 index 0000000..9f104b7 --- /dev/null +++ b/audit/__init__.py @@ -0,0 +1,131 @@ +import os +import re +import time +import uuid +import xml.etree.ElementTree as ET +from langchain_openai import ChatOpenAI +from langchain_openai import OpenAIEmbeddings +from langchain_core.messages import SystemMessage +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_community.chat_message_histories import ChatMessageHistory + +from logger import Logger +from audit import callback +from audit.prompt import SYSTEM_PROMPT + +reasoning_model = 'gemini-2.0-flash-thinking-exp' +embedding_model = 'text-embedding-3-large' + +xml_pattern = r'.*?' + + +class Audit: + def __init__(self): + self.raw_chain = None + self.source_files_list = [] + self.chat_history = ChatMessageHistory() + self.session_id = uuid.uuid4().hex + self.response_callback = callback.CustomCallbackHandler() + self.embedding = OpenAIEmbeddings(model=embedding_model) + self.llm = ChatOpenAI(model=reasoning_model, streaming=True, callbacks=[self.response_callback]) + self.log = Logger('audit') + self.prompt = ChatPromptTemplate.from_messages([ + SystemMessage(content=SYSTEM_PROMPT), + MessagesPlaceholder(variable_name='messages'), + ('human', '{input}'), + ]) + + def audit(self, callback_function): + self.log.info('Start auditing') + + input_content = '' + while True: + time.sleep(3) + result = self.send_message(input_content) + xml_match = re.search(xml_pattern, result, re.DOTALL) + + if xml_match: + xml_content = xml_match.group(0) + root = ET.fromstring(xml_content) + + action = root.find('action').text + content = root.find('content').text + + if action == 'QUERY STRUCTURE': + self.log.info('Request to query project structure') + input_content = '\n'.join(x for x in self.source_files_list) + continue + elif action == 'QUERY SOURCE': + self.log.info(f'Request source code: {content}') + input_content = open(content, 'r', encoding='utf-8').read() + continue + elif action == 'OUTPUT RESULT': + self.log.warning(f'Audit result: \n{content}\n') + callback_function(content) # Callback function, used to obtain results externally + input_content = '' + continue + else: + self.log.critical(f'Unknown action! {action}') + break + + def send_message(self, input_content): + self.response_callback.temp_content = '' + + if input_content == '': + input_content = 'nothing' + + input_dict = { + 'input': input_content, + 'context': '', + } + config_dict = { + 'configurable': {'session_id': self.session_id} + } + + self.raw_chain = self.prompt | self.llm + chain_with_message_history = RunnableWithMessageHistory( + self.raw_chain, + lambda session_id: self.chat_history, + input_messages_key='input', + history_messages_key='messages', + ) + + for _ in chain_with_message_history.stream(input_dict, config_dict): + pass + + return self.response_callback.temp_content + + def load_source_files(self, path, language): + self.log.info('Loading source files') + + if language == 'php': + suffixes = ['.php', '.php3', 'php4', 'php5'] + elif language == 'python': + suffixes = ['.py'] + elif language == 'java': + suffixes = ['.java'] + elif language == 'c': + suffixes = ['.c'] + elif language == 'c++': + suffixes = ['.cpp', 'cc'] + elif language == 'javascript': + suffixes = ['.js'] + elif language == 'go': + suffixes = ['.go'] + else: + self.log.critical('Language not supported!') + return + + for root, dirs, files in os.walk(path): + for file_name in files: + hit = False + for suffix in suffixes: + if file_name.endswith(suffix): + hit = True + break + + if hit: + self.source_files_list.append(os.path.join(root, file_name)) + + self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}') diff --git a/audit/callback.py b/audit/callback.py new file mode 100644 index 0000000..f9d7bf5 --- /dev/null +++ b/audit/callback.py @@ -0,0 +1,13 @@ +from langchain_core.callbacks import BaseCallbackHandler + + +class CustomCallbackHandler(BaseCallbackHandler): + def __init__(self): + self.temp_content = '' + + def on_llm_new_token(self, token: str, **kwargs) -> None: + self.temp_content += token + + def on_llm_end(self, response, **kwargs): + pass + diff --git a/audit/prompt.py b/audit/prompt.py new file mode 100644 index 0000000..ed9cbd4 --- /dev/null +++ b/audit/prompt.py @@ -0,0 +1,36 @@ +SYSTEM_PROMPT = """ +You are an intelligent code auditor. I will provide you with a source code. Please strictly follow the following requirements to conduct code audit. +All your output must strictly follow the following specifications. It is forbidden to output in any other form (including plain text, Markdown, etc.). +and it is forbidden to bring "`" when outputting: + +1. Query project structure: + +QUERY STRUCTURE + + + +2. Query code files + +QUERY SOURCE +the absolute path of the file you want to query + + +3. Output audit results + +OUTPUT RESULT +the audit results you want to output + + +4. End the audit task + +FINISH TASK + + + +Important things: +1. When the user sends you "nothing", you need to decide the next step based on the current audit progress; +2. When you make an action to query the project structure, the user will send you the following format (C:\\Users\\yvling\\Desktop\\PHP-Vuln\\src\\index.php), which is a text containing the absolute paths of several source code files. You need to construct the project structure that you can understand based on these contents; +3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow The above format outputs the absolute path of the file to be queried; +4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task; +5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest; +""" diff --git a/logger/__init__.py b/logger/__init__.py new file mode 100644 index 0000000..7bb79a1 --- /dev/null +++ b/logger/__init__.py @@ -0,0 +1,40 @@ +from datetime import datetime + + +LOG_COLORS = { + 'DEBUG': '\033[94m', # 蓝色 + 'INFO': '\033[92m', # 绿色 + 'WARNING': '\033[93m', # 黄色 + 'ERROR': '\033[91m', # 红色 + 'CRITICAL': '\033[95m' # 紫色 +} +RESET_COLOR = '\033[0m' + + +def log_with_color(level, message): + color = LOG_COLORS.get(level, RESET_COLOR) + prefix = f"[{level}]" + date = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + formatted_message = f"{color}{date} {prefix} {message}{RESET_COLOR}" + + print(formatted_message) + + +class Logger: + def __init__(self, name): + pass + + def debug(self, message): + log_with_color("DEBUG", message) + + def info(self, message): + log_with_color("INFO", message) + + def warning(self, message): + log_with_color("WARNING", message) + + def error(self, message): + log_with_color("ERROR", message) + + def critical(self, message): + log_with_color("CRITICAL", message) diff --git a/main.py b/main.py new file mode 100644 index 0000000..5b29b82 --- /dev/null +++ b/main.py @@ -0,0 +1,23 @@ +import os +import warnings +from audit import Audit + + +warnings.simplefilter('ignore', FutureWarning) + +os.environ['OPENAI_API_BASE'] = 'https://yunwu.ai/v1' +os.environ['OPENAI_API_KEY'] = 'sk-SQhmr2wNQa2BpohUrxgJOFIDY9ODSxUkLQLWWlPD9qDNVsN1' + + +def result_callback(result): + pass + + +if __name__ == '__main__': + src_root = r'C:\Users\yvling\Desktop\PHP-Vuln' + language = 'php' + + audit = Audit() + audit.load_source_files(src_root, language) + audit.audit(result_callback) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..24201a5 Binary files /dev/null and b/requirements.txt differ