From 32e61b8d247aecfd7e9877a54b88e3ced1cc5d2d Mon Sep 17 00:00:00 2001 From: yv1ing Date: Sun, 9 Feb 2025 00:59:35 +0800 Subject: [PATCH] Complete the basic process of code audit --- .gitignore | 2 +- audit/__init__.py | 131 +++++++++++++++++++++++++++++++++++++++++++++ audit/callback.py | 13 +++++ audit/prompt.py | 36 +++++++++++++ logger/__init__.py | 40 ++++++++++++++ main.py | 23 ++++++++ requirements.txt | Bin 0 -> 2046 bytes 7 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 audit/__init__.py create mode 100644 audit/callback.py create mode 100644 audit/prompt.py create mode 100644 logger/__init__.py create mode 100644 main.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore index 15201ac..d86e4ff 100644 --- a/.gitignore +++ b/.gitignore @@ -165,7 +165,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ # PyPI configuration file .pypirc diff --git a/audit/__init__.py b/audit/__init__.py new file mode 100644 index 0000000..9f104b7 --- /dev/null +++ b/audit/__init__.py @@ -0,0 +1,131 @@ +import os +import re +import time +import uuid +import xml.etree.ElementTree as ET +from langchain_openai import ChatOpenAI +from langchain_openai import OpenAIEmbeddings +from langchain_core.messages import SystemMessage +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_community.chat_message_histories import ChatMessageHistory + +from logger import Logger +from audit import callback +from audit.prompt import SYSTEM_PROMPT + +reasoning_model = 'gemini-2.0-flash-thinking-exp' +embedding_model = 'text-embedding-3-large' + +xml_pattern = r'.*?' + + +class Audit: + def __init__(self): + self.raw_chain = None + self.source_files_list = [] + self.chat_history = ChatMessageHistory() + self.session_id = uuid.uuid4().hex + self.response_callback = callback.CustomCallbackHandler() + self.embedding = OpenAIEmbeddings(model=embedding_model) + self.llm = ChatOpenAI(model=reasoning_model, streaming=True, callbacks=[self.response_callback]) + self.log = Logger('audit') + self.prompt = ChatPromptTemplate.from_messages([ + SystemMessage(content=SYSTEM_PROMPT), + MessagesPlaceholder(variable_name='messages'), + ('human', '{input}'), + ]) + + def audit(self, callback_function): + self.log.info('Start auditing') + + input_content = '' + while True: + time.sleep(3) + result = self.send_message(input_content) + xml_match = re.search(xml_pattern, result, re.DOTALL) + + if xml_match: + xml_content = xml_match.group(0) + root = ET.fromstring(xml_content) + + action = root.find('action').text + content = root.find('content').text + + if action == 'QUERY STRUCTURE': + self.log.info('Request to query project structure') + input_content = '\n'.join(x for x in self.source_files_list) + continue + elif action == 'QUERY SOURCE': + self.log.info(f'Request source code: {content}') + input_content = open(content, 'r', encoding='utf-8').read() + continue + elif action == 'OUTPUT RESULT': + self.log.warning(f'Audit result: \n{content}\n') + callback_function(content) # Callback function, used to obtain results externally + input_content = '' + continue + else: + self.log.critical(f'Unknown action! {action}') + break + + def send_message(self, input_content): + self.response_callback.temp_content = '' + + if input_content == '': + input_content = 'nothing' + + input_dict = { + 'input': input_content, + 'context': '', + } + config_dict = { + 'configurable': {'session_id': self.session_id} + } + + self.raw_chain = self.prompt | self.llm + chain_with_message_history = RunnableWithMessageHistory( + self.raw_chain, + lambda session_id: self.chat_history, + input_messages_key='input', + history_messages_key='messages', + ) + + for _ in chain_with_message_history.stream(input_dict, config_dict): + pass + + return self.response_callback.temp_content + + def load_source_files(self, path, language): + self.log.info('Loading source files') + + if language == 'php': + suffixes = ['.php', '.php3', 'php4', 'php5'] + elif language == 'python': + suffixes = ['.py'] + elif language == 'java': + suffixes = ['.java'] + elif language == 'c': + suffixes = ['.c'] + elif language == 'c++': + suffixes = ['.cpp', 'cc'] + elif language == 'javascript': + suffixes = ['.js'] + elif language == 'go': + suffixes = ['.go'] + else: + self.log.critical('Language not supported!') + return + + for root, dirs, files in os.walk(path): + for file_name in files: + hit = False + for suffix in suffixes: + if file_name.endswith(suffix): + hit = True + break + + if hit: + self.source_files_list.append(os.path.join(root, file_name)) + + self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}') diff --git a/audit/callback.py b/audit/callback.py new file mode 100644 index 0000000..f9d7bf5 --- /dev/null +++ b/audit/callback.py @@ -0,0 +1,13 @@ +from langchain_core.callbacks import BaseCallbackHandler + + +class CustomCallbackHandler(BaseCallbackHandler): + def __init__(self): + self.temp_content = '' + + def on_llm_new_token(self, token: str, **kwargs) -> None: + self.temp_content += token + + def on_llm_end(self, response, **kwargs): + pass + diff --git a/audit/prompt.py b/audit/prompt.py new file mode 100644 index 0000000..ed9cbd4 --- /dev/null +++ b/audit/prompt.py @@ -0,0 +1,36 @@ +SYSTEM_PROMPT = """ +You are an intelligent code auditor. I will provide you with a source code. Please strictly follow the following requirements to conduct code audit. +All your output must strictly follow the following specifications. It is forbidden to output in any other form (including plain text, Markdown, etc.). +and it is forbidden to bring "`" when outputting: + +1. Query project structure: + +QUERY STRUCTURE + + + +2. Query code files + +QUERY SOURCE +the absolute path of the file you want to query + + +3. Output audit results + +OUTPUT RESULT +the audit results you want to output + + +4. End the audit task + +FINISH TASK + + + +Important things: +1. When the user sends you "nothing", you need to decide the next step based on the current audit progress; +2. When you make an action to query the project structure, the user will send you the following format (C:\\Users\\yvling\\Desktop\\PHP-Vuln\\src\\index.php), which is a text containing the absolute paths of several source code files. You need to construct the project structure that you can understand based on these contents; +3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow The above format outputs the absolute path of the file to be queried; +4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task; +5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest; +""" diff --git a/logger/__init__.py b/logger/__init__.py new file mode 100644 index 0000000..7bb79a1 --- /dev/null +++ b/logger/__init__.py @@ -0,0 +1,40 @@ +from datetime import datetime + + +LOG_COLORS = { + 'DEBUG': '\033[94m', # 蓝色 + 'INFO': '\033[92m', # 绿色 + 'WARNING': '\033[93m', # 黄色 + 'ERROR': '\033[91m', # 红色 + 'CRITICAL': '\033[95m' # 紫色 +} +RESET_COLOR = '\033[0m' + + +def log_with_color(level, message): + color = LOG_COLORS.get(level, RESET_COLOR) + prefix = f"[{level}]" + date = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + formatted_message = f"{color}{date} {prefix} {message}{RESET_COLOR}" + + print(formatted_message) + + +class Logger: + def __init__(self, name): + pass + + def debug(self, message): + log_with_color("DEBUG", message) + + def info(self, message): + log_with_color("INFO", message) + + def warning(self, message): + log_with_color("WARNING", message) + + def error(self, message): + log_with_color("ERROR", message) + + def critical(self, message): + log_with_color("CRITICAL", message) diff --git a/main.py b/main.py new file mode 100644 index 0000000..5b29b82 --- /dev/null +++ b/main.py @@ -0,0 +1,23 @@ +import os +import warnings +from audit import Audit + + +warnings.simplefilter('ignore', FutureWarning) + +os.environ['OPENAI_API_BASE'] = 'https://yunwu.ai/v1' +os.environ['OPENAI_API_KEY'] = 'sk-SQhmr2wNQa2BpohUrxgJOFIDY9ODSxUkLQLWWlPD9qDNVsN1' + + +def result_callback(result): + pass + + +if __name__ == '__main__': + src_root = r'C:\Users\yvling\Desktop\PHP-Vuln' + language = 'php' + + audit = Audit() + audit.load_source_files(src_root, language) + audit.audit(result_callback) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..24201a55b45bb8fc231802b22f3419e25318d36d GIT binary patch literal 2046 zcmZ{l&2Ab&5QOI(DUU*AApXfA2j3GZQBFBo7}kJ+UBcqT_~Db(*X3qc*w$(Qd#Zb? ztE*@K{hgF$Y09EZ%2v;zRIzUL_gYU|KI<8k59OoQMR{M|X~z-!vR4HAW9<#~@Hy&X zS|R79j$GhCfK5!Ty+&CkT2+~qfxZvQSv{Y7ChVow*>lcXyPQof_+0h&V6!{&J=YGW zIdZCuJMAyaot7&)6N^NRg{87-=s=cA=s~dCl}+$pmZ!*)^>ThjHBze436sErfO)6b zEGnER|60G=$logey?!-1B_l2rbs0=%;@L(;u9jLpX}O*|oz(=NS{aQb2PkSXYA#fz z2MqNjb5Lf!(D(3o`Y;$fb|F07okm?%V60~{OvK|ceiZ%4%#MFlm995yL=~U0RLAi~ z$IJh}S4H??cg#$m?U__;8;a7eta~hH(ZdeWC{NwXg!AZ&PoWdF-07Z!Oon#zRb+dK z8iXbrb^1XuFg=_^x`S^;(v$v4W2p}B*YHiu0vg%1u!r5-5$N_PPN=>OZ_o#{^dx5n zPja5wsv0V>nu&28xS#07^BmyxKT&}v8PB@sD6&=-`nr@#cAWQ9^-aw8LHx`UcXCTG z1!kBobC+u6+^7<3%su3VVyEg5c_A*zhlxx#^gYJ@d1#J{re&qixj5hkaHE(5`)#Wn zoHGsNqSrfkxNg`n;dqU@or7JJw9=;-`ol~T!MWW%%YFBH-W%nnCi}-afKH|dWm|-j z1M%TrMcv=QCslRkZ({tjT$iuqOYG7kxH=Oj%(Xn(Bkfob-|(&n%rEtZn|j}4KNc2` z(w|cdw|k>!U2YYJQ&1|Ix!CWlAJ%vcX7T}G;>?Y8=W)kThFa)r@P&R z`u