Add graphical interface

This commit is contained in:
2025-02-09 21:27:07 +08:00
parent aee63534a6
commit 0a9f6d7fcd
11 changed files with 1482 additions and 1107 deletions

View File

@@ -13,33 +13,41 @@ from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_text_splitters import CharacterTextSplitter
from audit.rules import FROTIFY_RULES
from logger import Logger
from audit import callback
from audit.prompt import SYSTEM_PROMPT
from audit.language import LANGUAGE
reasoning_model = 'gpt-4o'
embedding_model = 'text-embedding-3-large'
xml_pattern = r'<root>.*?</root>'
class Audit:
def __init__(self, fortify_rules):
def __init__(self, base_url, api_key, reasoning_model, embedding_model, process_output_callback, result_output_callback):
self.raw_chain = None
self.source_files_list = []
self.max_token = 4096
self.fortify_rules = fortify_rules
self.reasoning_model = reasoning_model
self.embedding_model = embedding_model
self.fortify_rules = FROTIFY_RULES
self.process_output_callback = process_output_callback
self.result_output_callback = result_output_callback
self.chat_history = ChatMessageHistory()
self.session_id = uuid.uuid4().hex
self.response_callback = callback.CustomCallbackHandler()
self.embedding = OpenAIEmbeddings(model=embedding_model)
self.embedding = OpenAIEmbeddings(
base_url=base_url,
api_key=api_key,
model=embedding_model
)
self.llm = ChatOpenAI(
base_url=base_url,
api_key=api_key,
model=reasoning_model,
streaming=True,
callbacks=[self.response_callback]
)
self.log = Logger('audit')
self.log = Logger('audit', callback=self.process_output_callback)
self.splitter = CharacterTextSplitter(
chunk_size=300,
chunk_overlap=0,
@@ -65,12 +73,28 @@ class Audit:
('human', '{input}'),
])
def audit(self, callback_function):
self.log.info('Start auditing')
def audit(self, event):
if len(self.source_files_list) <= 0:
self.log.error('没有找到源代码文件')
return
self.log.info('开始代码审计流程')
self.log.info(f'当前推理模型:{self.reasoning_model}')
self.log.info(f'当前嵌入模型:{self.embedding_model}')
input_content = ''
while True:
result = self.send_message(input_content)
if event.is_set():
return
try:
result = self.send_message(input_content)
except Exception as e:
self.log.error(e)
return
if event.is_set():
return
if xml_match := re.search(xml_pattern, result, re.DOTALL):
try:
@@ -80,33 +104,36 @@ class Audit:
action = root.find('action').text
content = root.find('content').text
except Exception as e:
self.log.error(f'Illegal output, try to correct')
print(result)
print(e)
self.log.error(f'动作指令不合法,尝试纠正')
input_content = 'ILLEGAL OUTPUT'
continue
if action == 'QUERY STRUCTURE':
self.log.info('Request project structure')
self.log.info('请求查询项目结构')
input_content = '\n'.join(x for x in self.source_files_list)
continue
elif action == 'QUERY SOURCE':
self.log.info(f'Request source code: {content}')
self.log.info(f'请求查询源代码:{content}')
input_content = open(content, 'r', encoding='utf-8').read()
continue
elif action == 'QUERY FORTIFY':
self.log.info(f'Request fortify: {content}')
self.log.info(f'请求查询规则库:{content}')
input_content = '\n'.join(x for x in self.fortify_rules if x == content)
continue
elif action == 'OUTPUT RESULT':
self.log.warning(f'Audit result: \n\n{content}')
self.log.warning('输出代码审计结果')
self.result_output_callback(content)
self.store_messages_in_faiss(content)
callback_function(content) # Callback function, used to obtain results externally
input_content = ''
input_content = 'ok'
continue
elif action == 'FINISH TASK':
self.log.info(content)
self.log.info('代码审计任务已完成')
return
else:
self.log.critical(f'Unknown action! {action}')
break
self.log.error(f'动作指令未定义:{action}')
return
def send_message(self, input_content):
self.response_callback.temp_content = ''
@@ -140,20 +167,18 @@ class Audit:
text_embedding = self.embedding.embed_query(message)
doc_id = str(uuid.uuid4())
self.messages_db.add_embeddings([(doc_id, text_embedding)], metadatas=[{"id": doc_id}])
self.log.info(f"Audit result stored in messages_db with ID: {doc_id}")
self.log.info(f"代码审计结果已缓存,文档编号:{doc_id}")
def load_source_files(self, path, lang):
self.log.info('Loading source files')
if lang in LANGUAGE:
suffixes = LANGUAGE[lang]
else:
self.log.critical('Language not supported!')
self.log.error('不支持的编程语言')
return
for root, _, files in os.walk(path):
self.source_files_list.extend(
os.path.join(root, file) for file in files if any(file.endswith(suffix) for suffix in suffixes)
os.path.join(root, file).replace('\\', '/') for file in files if any(file.endswith(suffix) for suffix in suffixes)
)
self.log.info(f'Finished loading source files. total files: {len(self.source_files_list)}')
self.log.info(f'源代码文件加载完成,共:{len(self.source_files_list)}')