From 6a0cab65d1f327b8175e11acfff799d72d8dcd5d Mon Sep 17 00:00:00 2001 From: yv1ing Date: Sun, 9 Feb 2025 12:05:41 +0800 Subject: [PATCH] Add context compression mechanism --- audit/__init__.py | 101 +++++++++++++++++++++++++++++----------------- audit/language.py | 9 +++++ audit/prompt.py | 1 + main.py | 6 +-- 4 files changed, 76 insertions(+), 41 deletions(-) create mode 100644 audit/language.py diff --git a/audit/__init__.py b/audit/__init__.py index 9f104b7..730e9c3 100644 --- a/audit/__init__.py +++ b/audit/__init__.py @@ -2,17 +2,23 @@ import os import re import time import uuid +import tiktoken import xml.etree.ElementTree as ET -from langchain_openai import ChatOpenAI -from langchain_openai import OpenAIEmbeddings +from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_core.messages import SystemMessage from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_community.vectorstores import FAISS from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_community.document_transformers import EmbeddingsRedundantFilter +from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline +from langchain_text_splitters import CharacterTextSplitter from logger import Logger from audit import callback from audit.prompt import SYSTEM_PROMPT +from audit.language import LANGUAGE reasoning_model = 'gemini-2.0-flash-thinking-exp' embedding_model = 'text-embedding-3-large' @@ -24,12 +30,36 @@ class Audit: def __init__(self): self.raw_chain = None self.source_files_list = [] + self.max_token = 4096 self.chat_history = ChatMessageHistory() self.session_id = uuid.uuid4().hex self.response_callback = callback.CustomCallbackHandler() self.embedding = OpenAIEmbeddings(model=embedding_model) - self.llm = ChatOpenAI(model=reasoning_model, streaming=True, callbacks=[self.response_callback]) + self.llm = ChatOpenAI( + model=reasoning_model, + streaming=True, + callbacks=[self.response_callback] + ) self.log = Logger('audit') + self.splitter = CharacterTextSplitter( + chunk_size=300, + chunk_overlap=0, + separator=". ", + ) + self.messages_db = FAISS.from_texts(['nothing'], self.embedding) + self.retriever = self.messages_db.as_retriever() + self.redundant_filter = EmbeddingsRedundantFilter(embeddings=self.embedding) + self.relevant_filter = EmbeddingsFilter( + embeddings=self.embedding, + similarity_threshold=0.76, + ) + self.pipeline_compressor = DocumentCompressorPipeline( + transformers=[self.splitter, self.redundant_filter, self.relevant_filter] + ) + self.compression_retriever = ContextualCompressionRetriever( + base_compressor=self.pipeline_compressor, + base_retriever=self.retriever, + ) self.prompt = ChatPromptTemplate.from_messages([ SystemMessage(content=SYSTEM_PROMPT), MessagesPlaceholder(variable_name='messages'), @@ -41,19 +71,22 @@ class Audit: input_content = '' while True: - time.sleep(3) result = self.send_message(input_content) - xml_match = re.search(xml_pattern, result, re.DOTALL) - if xml_match: - xml_content = xml_match.group(0) - root = ET.fromstring(xml_content) + if xml_match := re.search(xml_pattern, result, re.DOTALL): + try: + xml_content = xml_match.group(0) + root = ET.fromstring(xml_content) - action = root.find('action').text - content = root.find('content').text + action = root.find('action').text + content = root.find('content').text + except Exception as e: + self.log.error(f'Illegal output, try to correct') + input_content = 'ILLEGAL OUTPUT' + continue if action == 'QUERY STRUCTURE': - self.log.info('Request to query project structure') + self.log.info('Request project structure') input_content = '\n'.join(x for x in self.source_files_list) continue elif action == 'QUERY SOURCE': @@ -61,7 +94,8 @@ class Audit: input_content = open(content, 'r', encoding='utf-8').read() continue elif action == 'OUTPUT RESULT': - self.log.warning(f'Audit result: \n{content}\n') + self.log.warning(f'Audit result: \n\n{content}') + self.store_messages_in_faiss(content) callback_function(content) # Callback function, used to obtain results externally input_content = '' continue @@ -71,13 +105,14 @@ class Audit: def send_message(self, input_content): self.response_callback.temp_content = '' + compressed_context = self.compression_retriever.invoke(input_content) if input_content == '': input_content = 'nothing' input_dict = { 'input': input_content, - 'context': '', + 'context': compressed_context, } config_dict = { 'configurable': {'session_id': self.session_id} @@ -91,41 +126,31 @@ class Audit: history_messages_key='messages', ) + self.log.debug(f'Chat messages: {input_dict}') + for _ in chain_with_message_history.stream(input_dict, config_dict): pass return self.response_callback.temp_content - def load_source_files(self, path, language): + def store_messages_in_faiss(self, message): + text_embedding = self.embedding.embed_query(message) + doc_id = str(uuid.uuid4()) + self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}]) + self.log.info(f"Audit result stored in messages_db with ID: {doc_id}") + + def load_source_files(self, path, lang): self.log.info('Loading source files') - if language == 'php': - suffixes = ['.php', '.php3', 'php4', 'php5'] - elif language == 'python': - suffixes = ['.py'] - elif language == 'java': - suffixes = ['.java'] - elif language == 'c': - suffixes = ['.c'] - elif language == 'c++': - suffixes = ['.cpp', 'cc'] - elif language == 'javascript': - suffixes = ['.js'] - elif language == 'go': - suffixes = ['.go'] + if lang in LANGUAGE: + suffixes = LANGUAGE[lang] else: self.log.critical('Language not supported!') return - for root, dirs, files in os.walk(path): - for file_name in files: - hit = False - for suffix in suffixes: - if file_name.endswith(suffix): - hit = True - break - - if hit: - self.source_files_list.append(os.path.join(root, file_name)) + for root, _, files in os.walk(path): + self.source_files_list.extend( + os.path.join(root, file) for file in files if any(file.endswith(suffix) for suffix in suffixes) + ) self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}') diff --git a/audit/language.py b/audit/language.py new file mode 100644 index 0000000..7215cd1 --- /dev/null +++ b/audit/language.py @@ -0,0 +1,9 @@ +LANGUAGE = { + 'c': ['.c'], + 'c++': ['.cc', '.cpp'], + 'go': ['.go'], + 'php': ['php', 'php3', 'php4', 'php5', 'phtml'], + 'java': ['.java'], + 'python': ['.py'], + 'javascript': ['.js'], +} diff --git a/audit/prompt.py b/audit/prompt.py index ed9cbd4..8ddff39 100644 --- a/audit/prompt.py +++ b/audit/prompt.py @@ -33,4 +33,5 @@ Important things: 3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow The above format outputs the absolute path of the file to be queried; 4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task; 5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest; +6. When the user prompts "ILLEGAL OUTPUT", it means that your output violates the user's specifications. Please confirm again that all your output must comply with the user's specifications. """ diff --git a/main.py b/main.py index 5b29b82..65e346c 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,7 @@ from audit import Audit warnings.simplefilter('ignore', FutureWarning) os.environ['OPENAI_API_BASE'] = 'https://yunwu.ai/v1' -os.environ['OPENAI_API_KEY'] = 'sk-SQhmr2wNQa2BpohUrxgJOFIDY9ODSxUkLQLWWlPD9qDNVsN1' +os.environ['OPENAI_API_KEY'] = 'sk-zpkHfWT0Zhvzc79lX11WS4dEyg5CkQ3RdZOSNDoLADaitfVM' def result_callback(result): @@ -14,8 +14,8 @@ def result_callback(result): if __name__ == '__main__': - src_root = r'C:\Users\yvling\Desktop\PHP-Vuln' - language = 'php' + src_root = r'C:\Users\yvling\Desktop\JavaSecLab' + language = 'java' audit = Audit() audit.load_source_files(src_root, language)