diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..44cc41c --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module rdp_channel + +go 1.24 diff --git a/protocol/tpkt/tpkt.go b/protocol/tpkt/tpkt.go new file mode 100644 index 0000000..f7551ad --- /dev/null +++ b/protocol/tpkt/tpkt.go @@ -0,0 +1,86 @@ +package tpkt + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "net" +) + +// 协议常量 +const ( + TPKT_VERSION = 0x03 + TPKT_RESERVED = 0x00 + TPKT_HEADER_LENGTH = 0x04 + TPKT_MAX_PACKET_LENGTH = 0xffff +) + +var ( + TPKT_INVALID_VERSION = errors.New("[TPKT] invalid version") + TPKT_INVALID_PACKET_LENGTH = errors.New("[TPKT] invalid packet length") +) + +// TPKT 协议封装 +type TPKT struct { + connection net.Conn + readBuff *bufio.Reader +} + +func New(conn net.Conn) *TPKT { + tpkt := &TPKT{ + connection: conn, + readBuff: bufio.NewReader(conn), + } + + return tpkt +} + +// Write 发送TPKT包 +func (tpkt *TPKT) Write(data []byte) (int, error) { + dataLen := len(data) + if dataLen > (TPKT_MAX_PACKET_LENGTH - TPKT_HEADER_LENGTH) { + return 0, TPKT_INVALID_PACKET_LENGTH + } + + // TPKT封包 + packet := make([]byte, TPKT_HEADER_LENGTH+dataLen) + packet[0] = TPKT_VERSION + packet[1] = TPKT_RESERVED + binary.BigEndian.PutUint16(packet[2:4], uint16(TPKT_HEADER_LENGTH+dataLen)) // 第2、3字节为tpkt包的长度 + + // 装填载荷 + copy(packet[TPKT_HEADER_LENGTH:], data) + + // 发送TPKT包 + return tpkt.connection.Write(packet) +} + +func (tpkt *TPKT) Read() (int, []byte, error) { + // 验证TPKT头 + header := make([]byte, TPKT_HEADER_LENGTH) + _, err := io.ReadFull(tpkt.readBuff, header) + if err != nil { + return 0, nil, err + } + + // 验证TPKT版本 + if header[0] != TPKT_VERSION { + return 0, nil, TPKT_INVALID_VERSION + } + + // 验证载荷长度 + length := binary.BigEndian.Uint16(header[2:4]) + if length > TPKT_MAX_PACKET_LENGTH { + return 0, nil, TPKT_INVALID_PACKET_LENGTH + } + + // 读取载荷数据 + dataLen := length - TPKT_HEADER_LENGTH + data := make([]byte, dataLen) + if _, err := io.ReadFull(tpkt.readBuff, data); err != nil { + return 0, nil, err + } + + return int(dataLen), data, nil +} diff --git a/protocol/tpkt/tpkt_test.go b/protocol/tpkt/tpkt_test.go new file mode 100644 index 0000000..a683958 --- /dev/null +++ b/protocol/tpkt/tpkt_test.go @@ -0,0 +1,68 @@ +package tpkt + +import ( + "net" + "testing" +) + +func TestTpkt(t *testing.T) { + go runServer(t) + runClient(t) +} + +func runServer(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:3388") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + t.Logf("tpkt server listening at %s\n", listener.Addr()) + + for { + conn, err := listener.Accept() + if err != nil { + t.Logf("tpkt server accept error: %s\n", err) + continue + } + + t.Logf("tpkt server accepted connection from %s\n", conn.RemoteAddr()) + go func(conn net.Conn) { + defer conn.Close() + + tpkt := New(conn) + dataLen, data, err := tpkt.Read() + if err != nil { + t.Logf("tpkt server read error: %s\n", err) + return + } + + t.Logf("tpkt server read data(%d): %q\n", dataLen, data) + + _, err = tpkt.Write([]byte("tpkt server hello")) + if err != nil { + t.Logf("tpkt server write error: %s\n", err) + } + }(conn) + } +} + +func runClient(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:3388") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + tpkt := New(conn) + + _, err = tpkt.Write([]byte("tpkt client hello")) + if err != nil { + t.Logf("tpkt client write error: %s\n", err) + } + + dataLen, data, err := tpkt.Read() + if err != nil { + t.Logf("tpkt client read error: %s\n", err) + } + t.Logf("tpkt client read data(%d): %q\n", dataLen, data) +}