Add context compression mechanism

This commit is contained in:
2025-02-09 12:05:41 +08:00
parent 32e61b8d24
commit 6a0cab65d1
4 changed files with 76 additions and 41 deletions

View File

@@ -2,17 +2,23 @@ import os
import re import re
import time import time
import uuid import uuid
import tiktoken
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_core.messages import SystemMessage from langchain_core.messages import SystemMessage
from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.vectorstores import FAISS
from langchain_community.chat_message_histories import ChatMessageHistory from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_text_splitters import CharacterTextSplitter
from logger import Logger from logger import Logger
from audit import callback from audit import callback
from audit.prompt import SYSTEM_PROMPT from audit.prompt import SYSTEM_PROMPT
from audit.language import LANGUAGE
reasoning_model = 'gemini-2.0-flash-thinking-exp' reasoning_model = 'gemini-2.0-flash-thinking-exp'
embedding_model = 'text-embedding-3-large' embedding_model = 'text-embedding-3-large'
@@ -24,12 +30,36 @@ class Audit:
def __init__(self): def __init__(self):
self.raw_chain = None self.raw_chain = None
self.source_files_list = [] self.source_files_list = []
self.max_token = 4096
self.chat_history = ChatMessageHistory() self.chat_history = ChatMessageHistory()
self.session_id = uuid.uuid4().hex self.session_id = uuid.uuid4().hex
self.response_callback = callback.CustomCallbackHandler() self.response_callback = callback.CustomCallbackHandler()
self.embedding = OpenAIEmbeddings(model=embedding_model) self.embedding = OpenAIEmbeddings(model=embedding_model)
self.llm = ChatOpenAI(model=reasoning_model, streaming=True, callbacks=[self.response_callback]) self.llm = ChatOpenAI(
model=reasoning_model,
streaming=True,
callbacks=[self.response_callback]
)
self.log = Logger('audit') self.log = Logger('audit')
self.splitter = CharacterTextSplitter(
chunk_size=300,
chunk_overlap=0,
separator=". ",
)
self.messages_db = FAISS.from_texts(['nothing'], self.embedding)
self.retriever = self.messages_db.as_retriever()
self.redundant_filter = EmbeddingsRedundantFilter(embeddings=self.embedding)
self.relevant_filter = EmbeddingsFilter(
embeddings=self.embedding,
similarity_threshold=0.76,
)
self.pipeline_compressor = DocumentCompressorPipeline(
transformers=[self.splitter, self.redundant_filter, self.relevant_filter]
)
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=self.pipeline_compressor,
base_retriever=self.retriever,
)
self.prompt = ChatPromptTemplate.from_messages([ self.prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=SYSTEM_PROMPT), SystemMessage(content=SYSTEM_PROMPT),
MessagesPlaceholder(variable_name='messages'), MessagesPlaceholder(variable_name='messages'),
@@ -41,19 +71,22 @@ class Audit:
input_content = '' input_content = ''
while True: while True:
time.sleep(3)
result = self.send_message(input_content) result = self.send_message(input_content)
xml_match = re.search(xml_pattern, result, re.DOTALL)
if xml_match: if xml_match := re.search(xml_pattern, result, re.DOTALL):
xml_content = xml_match.group(0) try:
root = ET.fromstring(xml_content) xml_content = xml_match.group(0)
root = ET.fromstring(xml_content)
action = root.find('action').text action = root.find('action').text
content = root.find('content').text content = root.find('content').text
except Exception as e:
self.log.error(f'Illegal output, try to correct')
input_content = 'ILLEGAL OUTPUT'
continue
if action == 'QUERY STRUCTURE': if action == 'QUERY STRUCTURE':
self.log.info('Request to query project structure') self.log.info('Request project structure')
input_content = '\n'.join(x for x in self.source_files_list) input_content = '\n'.join(x for x in self.source_files_list)
continue continue
elif action == 'QUERY SOURCE': elif action == 'QUERY SOURCE':
@@ -61,7 +94,8 @@ class Audit:
input_content = open(content, 'r', encoding='utf-8').read() input_content = open(content, 'r', encoding='utf-8').read()
continue continue
elif action == 'OUTPUT RESULT': elif action == 'OUTPUT RESULT':
self.log.warning(f'Audit result: \n{content}\n') self.log.warning(f'Audit result: \n\n{content}')
self.store_messages_in_faiss(content)
callback_function(content) # Callback function, used to obtain results externally callback_function(content) # Callback function, used to obtain results externally
input_content = '' input_content = ''
continue continue
@@ -71,13 +105,14 @@ class Audit:
def send_message(self, input_content): def send_message(self, input_content):
self.response_callback.temp_content = '' self.response_callback.temp_content = ''
compressed_context = self.compression_retriever.invoke(input_content)
if input_content == '': if input_content == '':
input_content = 'nothing' input_content = 'nothing'
input_dict = { input_dict = {
'input': input_content, 'input': input_content,
'context': '', 'context': compressed_context,
} }
config_dict = { config_dict = {
'configurable': {'session_id': self.session_id} 'configurable': {'session_id': self.session_id}
@@ -91,41 +126,31 @@ class Audit:
history_messages_key='messages', history_messages_key='messages',
) )
self.log.debug(f'Chat messages: {input_dict}')
for _ in chain_with_message_history.stream(input_dict, config_dict): for _ in chain_with_message_history.stream(input_dict, config_dict):
pass pass
return self.response_callback.temp_content return self.response_callback.temp_content
def load_source_files(self, path, language): def store_messages_in_faiss(self, message):
text_embedding = self.embedding.embed_query(message)
doc_id = str(uuid.uuid4())
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
self.log.info(f"Audit result stored in messages_db with ID: {doc_id}")
def load_source_files(self, path, lang):
self.log.info('Loading source files') self.log.info('Loading source files')
if language == 'php': if lang in LANGUAGE:
suffixes = ['.php', '.php3', 'php4', 'php5'] suffixes = LANGUAGE[lang]
elif language == 'python':
suffixes = ['.py']
elif language == 'java':
suffixes = ['.java']
elif language == 'c':
suffixes = ['.c']
elif language == 'c++':
suffixes = ['.cpp', 'cc']
elif language == 'javascript':
suffixes = ['.js']
elif language == 'go':
suffixes = ['.go']
else: else:
self.log.critical('Language not supported!') self.log.critical('Language not supported!')
return return
for root, dirs, files in os.walk(path): for root, _, files in os.walk(path):
for file_name in files: self.source_files_list.extend(
hit = False os.path.join(root, file) for file in files if any(file.endswith(suffix) for suffix in suffixes)
for suffix in suffixes: )
if file_name.endswith(suffix):
hit = True
break
if hit:
self.source_files_list.append(os.path.join(root, file_name))
self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}') self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}')

9
audit/language.py Normal file
View File

@@ -0,0 +1,9 @@
# Maps a supported language name to the file-name suffixes that identify its
# source files. The loader matches these with ``str.endswith``, so every
# suffix must include the leading dot; a bare 'php' would also match names
# such as 'notes_php' that merely end in those letters.
LANGUAGE = {
    'c': ['.c'],
    'c++': ['.cc', '.cpp'],
    'go': ['.go'],
    'php': ['.php', '.php3', '.php4', '.php5', '.phtml'],
    'java': ['.java'],
    'python': ['.py'],
    'javascript': ['.js'],
}

View File

@@ -33,4 +33,5 @@ Important things:
3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow the above format and output the absolute path of the file to be queried; 3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow the above format and output the absolute path of the file to be queried;
4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task; 4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task;
5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest; 5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest;
6. When the user replies "ILLEGAL OUTPUT", it means that your previous output violated the user's specifications. Please check it again and make sure that all of your output complies with the user's specifications.
""" """

View File

@@ -6,7 +6,7 @@ from audit import Audit
warnings.simplefilter('ignore', FutureWarning) warnings.simplefilter('ignore', FutureWarning)
os.environ['OPENAI_API_BASE'] = 'https://yunwu.ai/v1' os.environ['OPENAI_API_BASE'] = 'https://yunwu.ai/v1'
os.environ['OPENAI_API_KEY'] = 'sk-SQhmr2wNQa2BpohUrxgJOFIDY9ODSxUkLQLWWlPD9qDNVsN1' os.environ['OPENAI_API_KEY'] = 'sk-zpkHfWT0Zhvzc79lX11WS4dEyg5CkQ3RdZOSNDoLADaitfVM'
def result_callback(result): def result_callback(result):
@@ -14,8 +14,8 @@ def result_callback(result):
if __name__ == '__main__': if __name__ == '__main__':
src_root = r'C:\Users\yvling\Desktop\PHP-Vuln' src_root = r'C:\Users\yvling\Desktop\JavaSecLab'
language = 'php' language = 'java'
audit = Audit() audit = Audit()
audit.load_source_files(src_root, language) audit.load_source_files(src_root, language)