diff --git a/app/app.go b/app/app.go new file mode 100644 index 0000000..dbd2d44 --- /dev/null +++ b/app/app.go @@ -0,0 +1,5 @@ +package app + +type App interface { + Start() error +} diff --git a/app/client.go b/app/client.go new file mode 100644 index 0000000..81fa1f5 --- /dev/null +++ b/app/client.go @@ -0,0 +1,38 @@ +package app + +import ( + "fmt" + "log" + "net" + "rdp_channel/protocol" + "time" +) + +type Client struct { + Host string + Port int +} + +func NewClient(host string, port int) Client { + return Client{host, port} +} + +func (client Client) Start() error { + addr := fmt.Sprintf("%s:%d", client.Host, client.Port) + + conn, err := net.Dial("tcp", addr) + if err != nil { + return err + } + defer conn.Close() + log.Println("[Client] connected to " + addr) + + tpkt := protocol.NewTPKT(conn) + //fast := protocol.NewFastPath(conn) + //x224 := protocol.NewX224(conn) + + for { + time.Sleep(1 * time.Second) + err = tpkt.Write([]byte("This is a test message.")) + } +} diff --git a/app/client_test.go b/app/client_test.go new file mode 100644 index 0000000..e429325 --- /dev/null +++ b/app/client_test.go @@ -0,0 +1,11 @@ +package app + +import "testing" + +func TestClient(t *testing.T) { + c := NewClient("127.0.0.1", 3388) + err := c.Start() + if err != nil { + t.Fatal(err) + } +} diff --git a/app/server.go b/app/server.go new file mode 100644 index 0000000..e05d701 --- /dev/null +++ b/app/server.go @@ -0,0 +1,54 @@ +package app + +import ( + "fmt" + "log" + "net" + "rdp_channel/protocol" +) + +type Server struct { + Host string + Port int +} + +func NewServer(host string, port int) Server { + return Server{host, port} +} + +func (server Server) Start() error { + addr := fmt.Sprintf("%s:%d", server.Host, server.Port) + + listener, err := net.Listen("tcp", addr) + if err != nil { + return err + } + defer listener.Close() + log.Println("[SERVER] listening on " + addr) + + for { + conn, err := listener.Accept() + if err != nil { + continue + } + go handleConnection(conn) + } +} + +func handleConnection(conn net.Conn) { + defer conn.Close() + log.Println("[SERVER] new connection from " + conn.RemoteAddr().String()) + + tpkt := protocol.NewTPKT(conn) + + //fast := protocol.NewFastPath(conn) + //x224 := protocol.NewX224(conn) + for { + payload, err := tpkt.Read() + if err != nil { + continue + } + + log.Println("[SERVER] received payload: " + string(payload)) + } +} diff --git a/app/server_test.go b/app/server_test.go new file mode 100644 index 0000000..e657c8e --- /dev/null +++ b/app/server_test.go @@ -0,0 +1,11 @@ +package app + +import "testing" + +func TestServer(t *testing.T) { + s := NewServer("0.0.0.0", 3388) + err := s.Start() + if err != nil { + t.Fatal(err) + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..0ba51d0 --- /dev/null +++ b/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "flag" + "log" + "rdp_channel/app" +) + +func main() { + mode := flag.String("mode", "server", "server or client") + host := flag.String("host", "127.0.0.1", "server or client") + port := flag.Int("port", 8080, "server or client") + flag.Parse() + + var a app.App + switch *mode { + case "server": + a = app.NewServer(*host, *port) + case "client": + a = app.NewClient(*host, *port) + default: + log.Fatal("[APP] invalid mode: " + *mode) + } + + err := a.Start() + if err != nil { + log.Fatal(err) + } +} diff --git a/protocol/fastpath.go b/protocol/fastpath.go new file mode 100644 index 0000000..aedcc37 --- /dev/null +++ b/protocol/fastpath.go @@ -0,0 +1,230 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +/* + 定义协议相关的常量值 +*/ + +// FASTPATH_UPDATE_HEADER +// UpdateHeader(1 byte) +// updateCode(4 bits) | fragmentation(2 bits) | compression(2 bits) +// 参考:https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/a1c4caa8-00ed-45bb-a06e-5177473766d3 +const FASTPATH_UPDATE_HEADER uint8 = 0b00010000 + +/* + FastPath PDU定义 + 参考:https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/0ae3c114-1439-4465-8d3f-6585227eff7d +*/ + +type FastPathPDU struct { + UpdateHeader uint8 + CompressionFlags uint8 + Size uint16 + Payload BitmapUpdateData +} + +/* + Bitmap Data 定义 + 参考:https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/d681bb11-f3b5-4add-b092-19fe7075f9e3 + https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/84a3d4d2-5523-4e49-9a48-33952c559485 +*/ + +type BitmapUpdateData struct { + UpdateType uint16 + NumberRectangles uint16 + Payload BitmapData +} + +type BitmapData struct { + DestLeft uint16 + DestTop uint16 + DestRight uint16 + DestBottom uint16 + Width uint16 + Height uint16 + BitsPerPixel uint16 + Flags uint16 + BitmapLength uint16 + Payload []byte +} + +/* + FastPath协议封装 +*/ + +type FastPath struct { + transport *TPKT +} + +func NewFastPath(conn io.ReadWriter) *FastPath { + return &FastPath{transport: NewTPKT(conn)} +} + +/* + FastPath封包 +*/ + +func (f *FastPath) Write(payload []byte) error { + payloadLen := len(payload) + + // 构造BitmapData + bitmapData := BitmapData{ + DestLeft: 0x0000, + DestTop: 0x0000, + DestRight: 0x000f, // 15 = 0 + 16 -1 + DestBottom: 0x000f, // 15 + Width: 0x0010, // 16 + Height: 0x0010, // 16 + BitsPerPixel: 0x0010, // 16位每像素 + Flags: 0x0000, // 无压缩 + BitmapLength: uint16(payloadLen), + Payload: payload, + } + + // 序列化BitmapData + var bitmapDataBuff bytes.Buffer + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.DestLeft) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.DestTop) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.DestRight) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.DestBottom) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.Width) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.Height) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.BitsPerPixel) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.Flags) + binary.Write(&bitmapDataBuff, binary.LittleEndian, bitmapData.BitmapLength) + bitmapDataBuff.Write(bitmapData.Payload) + + // 构造BitmapUpdateData + bitmapUpdateData := BitmapUpdateData{ + UpdateType: 0x0001, + NumberRectangles: 0x0001, + Payload: bitmapData, + } + + // 序列化BitmapUpdateData + var updateDataBuff bytes.Buffer + binary.Write(&updateDataBuff, binary.LittleEndian, bitmapUpdateData.UpdateType) + binary.Write(&updateDataBuff, binary.LittleEndian, bitmapUpdateData.NumberRectangles) + updateDataBuff.Write(bitmapDataBuff.Bytes()) + + updateDataBytes := updateDataBuff.Bytes() + updateDataLength := len(updateDataBytes) + + // 构造FastPathPDU + var fastPathPDUBuff bytes.Buffer + fastPathPDUBuff.WriteByte(FASTPATH_UPDATE_HEADER) + fastPathPDUBuff.WriteByte(0x00) + + size := uint16(updateDataLength) + if size <= 0x7F { + fastPathPDUBuff.WriteByte(byte(size)) + } else { + fastPathPDUBuff.WriteByte(byte((size >> 8) | 0x80)) + fastPathPDUBuff.WriteByte(byte(size & 0xFF)) + } + + fastPathPDUBuff.Write(updateDataBytes) + return f.transport.Write(fastPathPDUBuff.Bytes()) +} + +/* + FastParh解包 +*/ + +func (f *FastPath) Read() (payload []byte, err error) { + packet, err := f.transport.Read() + if err != nil { + return nil, errors.New("[FASTPATH] read packet error: " + err.Error()) + } + + reader := bytes.NewReader(packet) + + // 解析FastPathPDU + fastPathPDU := FastPathPDU{} + + fastPathPDU.UpdateHeader, err = reader.ReadByte() + if err != nil { + return nil, errors.New("[FASTPATH] read fastpathpdu's update header error: " + err.Error()) + } + + fastPathPDU.CompressionFlags, err = reader.ReadByte() + if err != nil { + return nil, errors.New("[FASTPATH] read fastpathpdu's compression flags error: " + err.Error()) + } + + // 解析Size字段 + lengthByte, err := reader.ReadByte() + if err != nil { + return nil, errors.New("[FASTPATH] read fastpathpdu's size first byte error: " + err.Error()) + } + + var size uint16 + if (lengthByte & 0x80) != 0 { + lengthByte2, err := reader.ReadByte() + if err != nil { + return nil, errors.New("[FASTPATH] read fastpathpdu's size second byte error: " + err.Error()) + } + size = (uint16(lengthByte&0x7F) << 8) | uint16(lengthByte2) + } else { + size = uint16(lengthByte) + } + + fastPathPDU.Size = size + + // 提取FastPathPDU的Payload + fastPathPDUPayload := make([]byte, fastPathPDU.Size) + n, err := reader.Read(fastPathPDUPayload) + if err != nil || n != int(size) { + return nil, errors.New("[FASTPATH] read fastpathpdu's payload error: " + err.Error()) + } + + // 解析BitmapUpdateData + updateReader := bytes.NewReader(fastPathPDUPayload) + var updateType uint16 + if err := binary.Read(updateReader, binary.LittleEndian, &updateType); err != nil { + return nil, errors.New("[FASTPATH] read bitmapupdatedata's update type error: " + err.Error()) + } + + var numRects uint16 + if err := binary.Read(updateReader, binary.LittleEndian, &numRects); err != nil { + return nil, errors.New("[FASTPATH] read bitmapupdatedata's number rectangles error: " + err.Error()) + } + + // 解析BitmapData + var bitmapData BitmapData + fields := []interface{}{ + &bitmapData.DestLeft, + &bitmapData.DestTop, + &bitmapData.DestRight, + &bitmapData.DestBottom, + &bitmapData.Width, + &bitmapData.Height, + &bitmapData.BitsPerPixel, + &bitmapData.Flags, + &bitmapData.BitmapLength, + } + + for _, field := range fields { + if err := binary.Read(updateReader, binary.LittleEndian, field); err != nil { + return nil, errors.New("[FASTPATH] read bitmapdata field error: " + err.Error()) + } + } + + if bitmapData.BitmapLength > uint16(updateReader.Len()) { + return nil, errors.New("[FASTPATH] invalid bitmaplength") + } + + // 提取真实的Payload + bitmapPayload := make([]byte, bitmapData.BitmapLength) + if _, err := updateReader.Read(bitmapPayload); err != nil { + return nil, errors.New("[FASTPATH] read bitmappayload error: " + err.Error()) + } + + return bitmapPayload, nil +} diff --git a/protocol/tpkt.go b/protocol/tpkt.go new file mode 100644 index 0000000..95a0211 --- /dev/null +++ b/protocol/tpkt.go @@ -0,0 +1,91 @@ +package protocol + +import ( + "encoding/binary" + "errors" + "io" +) + +/* + 定义协议相关的常量值 +*/ + +const ( + TPKT_VERSION = 0x03 + TPKT_RESERVED = 0x00 + TPKT_HEADER_LENGTH = 0x04 + TPKT_MAX_PACKET_LENGTH = 0xffff +) + +/* + TPKT结构体封装 +*/ + +type TPKT struct { + conn io.ReadWriter +} + +func NewTPKT(conn io.ReadWriter) *TPKT { + return &TPKT{ + conn: conn, + } +} + +/* + TPKT封包 +*/ + +func (t *TPKT) Write(payload []byte) error { + pduLength := TPKT_HEADER_LENGTH + len(payload) + if pduLength > TPKT_MAX_PACKET_LENGTH { + return errors.New("[TPKT] packet length too long") + } + + // 构造TPKT头(4 bytes): + pdu := make([]byte, pduLength) + pdu[0] = TPKT_VERSION + pdu[1] = TPKT_RESERVED + binary.BigEndian.PutUint16(pdu[2:4], uint16(pduLength)) + + // 装入TPKT载荷 + copy(pdu[4:], payload) + + _, err := t.conn.Write(pdu) + if err != nil { + return errors.New("[TPKT] write error: " + err.Error()) + } + + return nil +} + +/* + TPKT解包 +*/ + +func (t *TPKT) Read() ([]byte, error) { + // 验证TPKT头 + pduHeader := make([]byte, TPKT_HEADER_LENGTH) + _, err := io.ReadFull(t.conn, pduHeader) + if err != nil { + return nil, errors.New("[TPKT] read pdu header error: " + err.Error()) + } + + if pduHeader[0] != TPKT_VERSION { + return nil, errors.New("[TPKT] version mismatch") + } + + pduLength := binary.BigEndian.Uint16(pduHeader[2:4]) + if pduLength > TPKT_MAX_PACKET_LENGTH { + return nil, errors.New("[TPKT] packet length too long") + } + + // 读取TPKT载荷 + payloadLength := pduLength - TPKT_HEADER_LENGTH + payload := make([]byte, payloadLength) + _, err = io.ReadFull(t.conn, payload) + if err != nil { + return nil, errors.New("[TPKT] read pdu payload error: " + err.Error()) + } + + return payload, nil +} diff --git a/protocol/x224.go b/protocol/x224.go new file mode 100644 index 0000000..c8c806d --- /dev/null +++ b/protocol/x224.go @@ -0,0 +1,139 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +/* 协议常量 */ + +// X224消息头部长度 +const ( + X224_HEADER_LENGTH = 0x07 +) + +// X224消息类型字段标识 +const ( + X224_CONNECTION_REQUEST byte = 0xE0 + X224_CONNECTION_CONFIRM byte = 0xD0 + X224_DISCONNECT_REQUEST byte = 0x80 + X224_DATA byte = 0xF0 + X224_ERROR byte = 0x70 +) + +// X224 协议封装 +type X224 struct { + transport *TPKT + reqProtocol uint32 + selProtocol uint32 +} + +type X224PDU struct { + Len uint8 + Type byte + DstRef uint16 // 大端序 + SrcRef uint16 // 大端序 + ClsOpt uint8 + Payload []byte +} + +func NewX224(conn io.ReadWriter) *X224 { + return &X224{ + transport: NewTPKT(conn), + } +} + +// 从字节流中解析PDU头部 +func (x *X224) parsePDUHeader(reader *bytes.Reader, pdu *X224PDU) error { + var err error + + // 读取Len字段 + err = binary.Read(reader, binary.BigEndian, &pdu.Len) + if err != nil { + return errors.New("[X224] failed to read pdu length: " + err.Error()) + } + + // 读取Type字段 + err = binary.Read(reader, binary.BigEndian, &pdu.Type) + if err != nil { + return errors.New("[X224] failed to read pdu type: " + err.Error()) + } + + // 读取DstRef(大端序) + err = binary.Read(reader, binary.BigEndian, &pdu.DstRef) + if err != nil { + return errors.New("[X224] failed to read pdu dstref: " + err.Error()) + } + + // 读取SrcRef(大端序) + err = binary.Read(reader, binary.BigEndian, &pdu.SrcRef) + if err != nil { + return errors.New("[X224] failed to read pdu srcref: " + err.Error()) + } + + // 读取ClsOpt + err = binary.Read(reader, binary.BigEndian, &pdu.ClsOpt) + if err != nil { + return errors.New("[X224] failed to read pdu clsopt: " + err.Error()) + } + + return nil +} + +// 序列化X224PDU +func (x *X224) serializeX224PDU(pdu *X224PDU) []byte { + buff := bytes.NewBuffer(nil) + _ = binary.Write(buff, binary.BigEndian, pdu.Len) + _ = binary.Write(buff, binary.BigEndian, pdu.Type) + _ = binary.Write(buff, binary.BigEndian, pdu.DstRef) + _ = binary.Write(buff, binary.BigEndian, pdu.SrcRef) + _ = binary.Write(buff, binary.BigEndian, pdu.ClsOpt) + + buff.Write(pdu.Payload) + return buff.Bytes() +} + +// 封包 +func (x *X224) Write(payload []byte) error { + pdu := &X224PDU{ + Len: uint8(X224_HEADER_LENGTH + len(payload)), // 头部长度 + 载荷字段 + Type: X224_DATA, + DstRef: 0xf0, + SrcRef: 0xf1, + ClsOpt: 0x0, + Payload: payload, + } + + payloadBytes := x.serializeX224PDU(pdu) + err := x.transport.Write(payloadBytes) + if err != nil { + return errors.New("[X224] failed to write: " + err.Error()) + } + + return nil +} + +// 解包 +func (x *X224) Read() ([]byte, error) { + packet, err := x.transport.Read() + if err != nil { + return nil, errors.New("[X224] failed to read: " + err.Error()) + } + + pdu := &X224PDU{} + reader := bytes.NewReader(packet) + + err = x.parsePDUHeader(reader, pdu) + if err != nil { + return nil, errors.New("[X224] failed to parse pdu header: " + err.Error()) + } + + payload, err := io.ReadAll(reader) + if err != nil { + return nil, errors.New("[X224] failed to read payload: " + err.Error()) + } + + return payload, nil +}