基本完成TPKT协议

This commit is contained in:
2025-03-30 00:31:57 +08:00
parent 16ee0ca642
commit 550abb8333
3 changed files with 157 additions and 0 deletions

3
go.mod Normal file
View File

@@ -0,0 +1,3 @@
module rdp_channel
go 1.24

86
protocol/tpkt/tpkt.go Normal file
View File

@@ -0,0 +1,86 @@
package tpkt
import (
"bufio"
"encoding/binary"
"errors"
"io"
"net"
)
// 协议常量
const (
TPKT_VERSION = 0x03
TPKT_RESERVED = 0x00
TPKT_HEADER_LENGTH = 0x04
TPKT_MAX_PACKET_LENGTH = 0xffff
)
var (
TPKT_INVALID_VERSION = errors.New("[TPKT] invalid version")
TPKT_INVALID_PACKET_LENGTH = errors.New("[TPKT] invalid packet length")
)
// TPKT 协议封装
type TPKT struct {
connection net.Conn
readBuff *bufio.Reader
}
func New(conn net.Conn) *TPKT {
tpkt := &TPKT{
connection: conn,
readBuff: bufio.NewReader(conn),
}
return tpkt
}
// Write 发送TPKT包
func (tpkt *TPKT) Write(data []byte) (int, error) {
dataLen := len(data)
if dataLen > (TPKT_MAX_PACKET_LENGTH - TPKT_HEADER_LENGTH) {
return 0, TPKT_INVALID_PACKET_LENGTH
}
// TPKT封包
packet := make([]byte, TPKT_HEADER_LENGTH+dataLen)
packet[0] = TPKT_VERSION
packet[1] = TPKT_RESERVED
binary.BigEndian.PutUint16(packet[2:4], uint16(TPKT_HEADER_LENGTH+dataLen)) // 第2、3字节为tpkt包的长度
// 装填载荷
copy(packet[TPKT_HEADER_LENGTH:], data)
// 发送TPKT包
return tpkt.connection.Write(packet)
}
func (tpkt *TPKT) Read() (int, []byte, error) {
// 验证TPKT头
header := make([]byte, TPKT_HEADER_LENGTH)
_, err := io.ReadFull(tpkt.readBuff, header)
if err != nil {
return 0, nil, err
}
// 验证TPKT版本
if header[0] != TPKT_VERSION {
return 0, nil, TPKT_INVALID_VERSION
}
// 验证载荷长度
length := binary.BigEndian.Uint16(header[2:4])
if length > TPKT_MAX_PACKET_LENGTH {
return 0, nil, TPKT_INVALID_PACKET_LENGTH
}
// 读取载荷数据
dataLen := length - TPKT_HEADER_LENGTH
data := make([]byte, dataLen)
if _, err := io.ReadFull(tpkt.readBuff, data); err != nil {
return 0, nil, err
}
return int(dataLen), data, nil
}

View File

@@ -0,0 +1,68 @@
package tpkt
import (
"net"
"testing"
)
func TestTpkt(t *testing.T) {
go runServer(t)
runClient(t)
}
func runServer(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:3388")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
t.Logf("tpkt server listening at %s\n", listener.Addr())
for {
conn, err := listener.Accept()
if err != nil {
t.Logf("tpkt server accept error: %s\n", err)
continue
}
t.Logf("tpkt server accepted connection from %s\n", conn.RemoteAddr())
go func(conn net.Conn) {
defer conn.Close()
tpkt := New(conn)
dataLen, data, err := tpkt.Read()
if err != nil {
t.Logf("tpkt server read error: %s\n", err)
return
}
t.Logf("tpkt server read data(%d): %q\n", dataLen, data)
_, err = tpkt.Write([]byte("tpkt server hello"))
if err != nil {
t.Logf("tpkt server write error: %s\n", err)
}
}(conn)
}
}
func runClient(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:3388")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
tpkt := New(conn)
_, err = tpkt.Write([]byte("tpkt client hello"))
if err != nil {
t.Logf("tpkt client write error: %s\n", err)
}
dataLen, data, err := tpkt.Read()
if err != nil {
t.Logf("tpkt client read error: %s\n", err)
}
t.Logf("tpkt client read data(%d): %q\n", dataLen, data)
}