mirror of
https://github.com/yv1ing/MollyAudit.git
synced 2025-09-16 14:55:50 +08:00
Add context compression mechanism
This commit is contained in:
@@ -2,17 +2,23 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import tiktoken
|
||||
import xml.etree.ElementTree as ET
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_community.chat_message_histories import ChatMessageHistory
|
||||
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
|
||||
from logger import Logger
|
||||
from audit import callback
|
||||
from audit.prompt import SYSTEM_PROMPT
|
||||
from audit.language import LANGUAGE
|
||||
|
||||
reasoning_model = 'gemini-2.0-flash-thinking-exp'
|
||||
embedding_model = 'text-embedding-3-large'
|
||||
@@ -24,12 +30,36 @@ class Audit:
|
||||
def __init__(self):
|
||||
self.raw_chain = None
|
||||
self.source_files_list = []
|
||||
self.max_token = 4096
|
||||
self.chat_history = ChatMessageHistory()
|
||||
self.session_id = uuid.uuid4().hex
|
||||
self.response_callback = callback.CustomCallbackHandler()
|
||||
self.embedding = OpenAIEmbeddings(model=embedding_model)
|
||||
self.llm = ChatOpenAI(model=reasoning_model, streaming=True, callbacks=[self.response_callback])
|
||||
self.llm = ChatOpenAI(
|
||||
model=reasoning_model,
|
||||
streaming=True,
|
||||
callbacks=[self.response_callback]
|
||||
)
|
||||
self.log = Logger('audit')
|
||||
self.splitter = CharacterTextSplitter(
|
||||
chunk_size=300,
|
||||
chunk_overlap=0,
|
||||
separator=". ",
|
||||
)
|
||||
self.messages_db = FAISS.from_texts(['nothing'], self.embedding)
|
||||
self.retriever = self.messages_db.as_retriever()
|
||||
self.redundant_filter = EmbeddingsRedundantFilter(embeddings=self.embedding)
|
||||
self.relevant_filter = EmbeddingsFilter(
|
||||
embeddings=self.embedding,
|
||||
similarity_threshold=0.76,
|
||||
)
|
||||
self.pipeline_compressor = DocumentCompressorPipeline(
|
||||
transformers=[self.splitter, self.redundant_filter, self.relevant_filter]
|
||||
)
|
||||
self.compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=self.pipeline_compressor,
|
||||
base_retriever=self.retriever,
|
||||
)
|
||||
self.prompt = ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content=SYSTEM_PROMPT),
|
||||
MessagesPlaceholder(variable_name='messages'),
|
||||
@@ -41,19 +71,22 @@ class Audit:
|
||||
|
||||
input_content = ''
|
||||
while True:
|
||||
time.sleep(3)
|
||||
result = self.send_message(input_content)
|
||||
xml_match = re.search(xml_pattern, result, re.DOTALL)
|
||||
|
||||
if xml_match:
|
||||
xml_content = xml_match.group(0)
|
||||
root = ET.fromstring(xml_content)
|
||||
if xml_match := re.search(xml_pattern, result, re.DOTALL):
|
||||
try:
|
||||
xml_content = xml_match.group(0)
|
||||
root = ET.fromstring(xml_content)
|
||||
|
||||
action = root.find('action').text
|
||||
content = root.find('content').text
|
||||
action = root.find('action').text
|
||||
content = root.find('content').text
|
||||
except Exception as e:
|
||||
self.log.error(f'Illegal output, try to correct')
|
||||
input_content = 'ILLEGAL OUTPUT'
|
||||
continue
|
||||
|
||||
if action == 'QUERY STRUCTURE':
|
||||
self.log.info('Request to query project structure')
|
||||
self.log.info('Request project structure')
|
||||
input_content = '\n'.join(x for x in self.source_files_list)
|
||||
continue
|
||||
elif action == 'QUERY SOURCE':
|
||||
@@ -61,7 +94,8 @@ class Audit:
|
||||
input_content = open(content, 'r', encoding='utf-8').read()
|
||||
continue
|
||||
elif action == 'OUTPUT RESULT':
|
||||
self.log.warning(f'Audit result: \n{content}\n')
|
||||
self.log.warning(f'Audit result: \n\n{content}')
|
||||
self.store_messages_in_faiss(content)
|
||||
callback_function(content) # Callback function, used to obtain results externally
|
||||
input_content = ''
|
||||
continue
|
||||
@@ -71,13 +105,14 @@ class Audit:
|
||||
|
||||
def send_message(self, input_content):
|
||||
self.response_callback.temp_content = ''
|
||||
compressed_context = self.compression_retriever.invoke(input_content)
|
||||
|
||||
if input_content == '':
|
||||
input_content = 'nothing'
|
||||
|
||||
input_dict = {
|
||||
'input': input_content,
|
||||
'context': '',
|
||||
'context': compressed_context,
|
||||
}
|
||||
config_dict = {
|
||||
'configurable': {'session_id': self.session_id}
|
||||
@@ -91,41 +126,31 @@ class Audit:
|
||||
history_messages_key='messages',
|
||||
)
|
||||
|
||||
self.log.debug(f'Chat messages: {input_dict}')
|
||||
|
||||
for _ in chain_with_message_history.stream(input_dict, config_dict):
|
||||
pass
|
||||
|
||||
return self.response_callback.temp_content
|
||||
|
||||
def load_source_files(self, path, language):
|
||||
def store_messages_in_faiss(self, message):
|
||||
text_embedding = self.embedding.embed_query(message)
|
||||
doc_id = str(uuid.uuid4())
|
||||
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
|
||||
self.log.info(f"Audit result stored in messages_db with ID: {doc_id}")
|
||||
|
||||
def load_source_files(self, path, lang):
|
||||
self.log.info('Loading source files')
|
||||
|
||||
if language == 'php':
|
||||
suffixes = ['.php', '.php3', 'php4', 'php5']
|
||||
elif language == 'python':
|
||||
suffixes = ['.py']
|
||||
elif language == 'java':
|
||||
suffixes = ['.java']
|
||||
elif language == 'c':
|
||||
suffixes = ['.c']
|
||||
elif language == 'c++':
|
||||
suffixes = ['.cpp', 'cc']
|
||||
elif language == 'javascript':
|
||||
suffixes = ['.js']
|
||||
elif language == 'go':
|
||||
suffixes = ['.go']
|
||||
if lang in LANGUAGE:
|
||||
suffixes = LANGUAGE[lang]
|
||||
else:
|
||||
self.log.critical('Language not supported!')
|
||||
return
|
||||
|
||||
for root, dirs, files in os.walk(path):
|
||||
for file_name in files:
|
||||
hit = False
|
||||
for suffix in suffixes:
|
||||
if file_name.endswith(suffix):
|
||||
hit = True
|
||||
break
|
||||
|
||||
if hit:
|
||||
self.source_files_list.append(os.path.join(root, file_name))
|
||||
for root, _, files in os.walk(path):
|
||||
self.source_files_list.extend(
|
||||
os.path.join(root, file) for file in files if any(file.endswith(suffix) for suffix in suffixes)
|
||||
)
|
||||
|
||||
self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}')
|
||||
|
||||
9
audit/language.py
Normal file
9
audit/language.py
Normal file
@@ -0,0 +1,9 @@
|
||||
LANGUAGE = {
|
||||
'c': ['.c'],
|
||||
'c++': ['.cc', '.cpp'],
|
||||
'go': ['.go'],
|
||||
'php': ['php', 'php3', 'php4', 'php5', 'phtml'],
|
||||
'java': ['.java'],
|
||||
'python': ['.py'],
|
||||
'javascript': ['.js'],
|
||||
}
|
||||
@@ -33,4 +33,5 @@ Important things:
|
||||
3. When you need to query the content of a code file, please note that you can only query one file at a time. Please follow The above format outputs the absolute path of the file to be queried;
|
||||
4. After you output the audit results, the user will reply with an empty string. Please make sure that all code files have been audited before ending the audit task;
|
||||
5. In any case, you must strictly follow the several action formats given above for output. Any content outside the output format is prohibited. Do not try to ask or suggest;
|
||||
6. When the user prompts "ILLEGAL OUTPUT", it means that your output violates the user's specifications. Please confirm again that all your output must comply with the user's specifications.
|
||||
"""
|
||||
|
||||
6
main.py
6
main.py
@@ -6,7 +6,7 @@ from audit import Audit
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
|
||||
os.environ['OPENAI_API_BASE'] = 'https://yunwu.ai/v1'
|
||||
os.environ['OPENAI_API_KEY'] = 'sk-SQhmr2wNQa2BpohUrxgJOFIDY9ODSxUkLQLWWlPD9qDNVsN1'
|
||||
os.environ['OPENAI_API_KEY'] = 'sk-zpkHfWT0Zhvzc79lX11WS4dEyg5CkQ3RdZOSNDoLADaitfVM'
|
||||
|
||||
|
||||
def result_callback(result):
|
||||
@@ -14,8 +14,8 @@ def result_callback(result):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
src_root = r'C:\Users\yvling\Desktop\PHP-Vuln'
|
||||
language = 'php'
|
||||
src_root = r'C:\Users\yvling\Desktop\JavaSecLab'
|
||||
language = 'java'
|
||||
|
||||
audit = Audit()
|
||||
audit.load_source_files(src_root, language)
|
||||
|
||||
Reference in New Issue
Block a user