Add context compression mechanism

This commit is contained in:
2025-02-09 12:05:41 +08:00
parent 32e61b8d24
commit 6a0cab65d1
4 changed files with 76 additions and 41 deletions

View File

@@ -2,17 +2,23 @@ import os
import re
import time
import uuid
import tiktoken
import xml.etree.ElementTree as ET
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.messages import SystemMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.vectorstores import FAISS
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_text_splitters import CharacterTextSplitter
from logger import Logger
from audit import callback
from audit.prompt import SYSTEM_PROMPT
from audit.language import LANGUAGE
reasoning_model = 'gemini-2.0-flash-thinking-exp'
embedding_model = 'text-embedding-3-large'
@@ -24,12 +30,36 @@ class Audit:
def __init__(self):
self.raw_chain = None
self.source_files_list = []
self.max_token = 4096
self.chat_history = ChatMessageHistory()
self.session_id = uuid.uuid4().hex
self.response_callback = callback.CustomCallbackHandler()
self.embedding = OpenAIEmbeddings(model=embedding_model)
self.llm = ChatOpenAI(model=reasoning_model, streaming=True, callbacks=[self.response_callback])
self.llm = ChatOpenAI(
model=reasoning_model,
streaming=True,
callbacks=[self.response_callback]
)
self.log = Logger('audit')
self.splitter = CharacterTextSplitter(
chunk_size=300,
chunk_overlap=0,
separator=". ",
)
self.messages_db = FAISS.from_texts(['nothing'], self.embedding)
self.retriever = self.messages_db.as_retriever()
self.redundant_filter = EmbeddingsRedundantFilter(embeddings=self.embedding)
self.relevant_filter = EmbeddingsFilter(
embeddings=self.embedding,
similarity_threshold=0.76,
)
self.pipeline_compressor = DocumentCompressorPipeline(
transformers=[self.splitter, self.redundant_filter, self.relevant_filter]
)
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=self.pipeline_compressor,
base_retriever=self.retriever,
)
self.prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=SYSTEM_PROMPT),
MessagesPlaceholder(variable_name='messages'),
@@ -41,19 +71,22 @@ class Audit:
input_content = ''
while True:
time.sleep(3)
result = self.send_message(input_content)
xml_match = re.search(xml_pattern, result, re.DOTALL)
if xml_match:
if xml_match := re.search(xml_pattern, result, re.DOTALL):
try:
xml_content = xml_match.group(0)
root = ET.fromstring(xml_content)
action = root.find('action').text
content = root.find('content').text
except Exception as e:
self.log.error(f'Illegal output, try to correct')
input_content = 'ILLEGAL OUTPUT'
continue
if action == 'QUERY STRUCTURE':
self.log.info('Request to query project structure')
self.log.info('Request project structure')
input_content = '\n'.join(x for x in self.source_files_list)
continue
elif action == 'QUERY SOURCE':
@@ -61,7 +94,8 @@ class Audit:
input_content = open(content, 'r', encoding='utf-8').read()
continue
elif action == 'OUTPUT RESULT':
self.log.warning(f'Audit result: \n{content}\n')
self.log.warning(f'Audit result: \n\n{content}')
self.store_messages_in_faiss(content)
callback_function(content) # Callback function, used to obtain results externally
input_content = ''
continue
@@ -71,13 +105,14 @@ class Audit:
def send_message(self, input_content):
self.response_callback.temp_content = ''
compressed_context = self.compression_retriever.invoke(input_content)
if input_content == '':
input_content = 'nothing'
input_dict = {
'input': input_content,
'context': '',
'context': compressed_context,
}
config_dict = {
'configurable': {'session_id': self.session_id}
@@ -91,41 +126,31 @@ class Audit:
history_messages_key='messages',
)
self.log.debug(f'Chat messages: {input_dict}')
for _ in chain_with_message_history.stream(input_dict, config_dict):
pass
return self.response_callback.temp_content
def load_source_files(self, path, language):
def store_messages_in_faiss(self, message):
text_embedding = self.embedding.embed_query(message)
doc_id = str(uuid.uuid4())
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
self.log.info(f"Audit result stored in messages_db with ID: {doc_id}")
def load_source_files(self, path, lang):
self.log.info('Loading source files')
if language == 'php':
suffixes = ['.php', '.php3', 'php4', 'php5']
elif language == 'python':
suffixes = ['.py']
elif language == 'java':
suffixes = ['.java']
elif language == 'c':
suffixes = ['.c']
elif language == 'c++':
suffixes = ['.cpp', 'cc']
elif language == 'javascript':
suffixes = ['.js']
elif language == 'go':
suffixes = ['.go']
if lang in LANGUAGE:
suffixes = LANGUAGE[lang]
else:
self.log.critical('Language not supported!')
return
for root, dirs, files in os.walk(path):
for file_name in files:
hit = False
for suffix in suffixes:
if file_name.endswith(suffix):
hit = True
break
if hit:
self.source_files_list.append(os.path.join(root, file_name))
for root, _, files in os.walk(path):
self.source_files_list.extend(
os.path.join(root, file) for file in files if any(file.endswith(suffix) for suffix in suffixes)
)
self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}')

9
audit/language.py Normal file
View File

@@ -0,0 +1,9 @@
LANGUAGE = {
'c': ['.c'],
'c++': ['.cc', '.cpp'],
'go': ['.go'],
'php': ['php', 'php3', 'php4', 'php5', 'phtml'],
'java': ['.java'],
'python': ['.py'],
'javascript': ['.js'],
}

View File

@@ -33,4 +33,5 @@ Important things:
3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow The above format outputs the absolute path of the file to be queried;
4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task;
5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest;
6. When the user prompts "ILLEGAL OUTPUT", it means that your output violates the user's specifications. Please confirm again that all your output must comply with the user's specifications.
"""

View File

@@ -6,7 +6,7 @@ from audit import Audit
warnings.simplefilter('ignore', FutureWarning)
os.environ['OPENAI_API_BASE'] = 'https://yunwu.ai/v1'
os.environ['OPENAI_API_KEY'] = 'sk-SQhmr2wNQa2BpohUrxgJOFIDY9ODSxUkLQLWWlPD9qDNVsN1'
os.environ['OPENAI_API_KEY'] = 'sk-zpkHfWT0Zhvzc79lX11WS4dEyg5CkQ3RdZOSNDoLADaitfVM'
def result_callback(result):
@@ -14,8 +14,8 @@ def result_callback(result):
if __name__ == '__main__':
src_root = r'C:\Users\yvling\Desktop\PHP-Vuln'
language = 'php'
src_root = r'C:\Users\yvling\Desktop\JavaSecLab'
language = 'java'
audit = Audit()
audit.load_source_files(src_root, language)