diff --git a/protocol/fastpath/fastpath.go b/protocol/fastpath/fastpath.go new file mode 100644 index 0000000..70754b7 --- /dev/null +++ b/protocol/fastpath/fastpath.go @@ -0,0 +1,68 @@ +package fastpath + +import ( + "encoding/binary" + "errors" + "rdp_channel/protocol/core/transport" +) + +// 协议常量 +const ( + FASTPATH_PDU_HEADER_LENGTH = 0x4 // updateHeader(1 bytes) + compressionFlags(1 bytes) + size(2 bytes) + FASTPATH__MAX_PACKET_LENGTH = 0xffff + + FASTPATH_UPDATETYPE_ORDERS uint32 = 0x0 + FASTPATH_UPDATETYPE_BITMAP uint32 = 0x1 +) + +const ( + // update header: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/a1c4caa8-00ed-45bb-a06e-5177473766d3 + FASTPATH_UPDATE_HEADER uint8 = 0x8 + FASTPATH_COMPRESSION_FLAGS uint8 = 0x0 +) + +var ( + FASTPATH_INVALID_PACKET_LENGTH = errors.New("[FASTPATH] invalid packet length") +) + +type FastPath struct { + transport transport.Transport +} + +func New(transport transport.Transport) *FastPath { + return &FastPath{ + transport: transport, + } +} + +func (fp *FastPath) Write(data []byte) (int, error) { + /* + 构造PDU + 格式参考:https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/0ae3c114-1439-4465-8d3f-6585227eff7d + */ + dataLen := len(data) + if dataLen > FASTPATH__MAX_PACKET_LENGTH { + return dataLen, FASTPATH_INVALID_PACKET_LENGTH + } + + size := uint16(FASTPATH_PDU_HEADER_LENGTH + dataLen) + pdu := make([]byte, size) + + pdu[0] = FASTPATH_UPDATE_HEADER + pdu[1] = FASTPATH_COMPRESSION_FLAGS + binary.LittleEndian.PutUint16(pdu[2:4], size) + + copy(pdu[4:], data) + + return fp.transport.Write(pdu) +} + +func (fp *FastPath) Read() (int, []byte, error) { + _, packet, err := fp.transport.Read() + if err != nil { + return 0, nil, err + } + + data := packet[FASTPATH_PDU_HEADER_LENGTH:] + return len(data), data, nil +} diff --git a/protocol/fastpath/fastpath_test.go b/protocol/fastpath/fastpath_test.go new file mode 100644 index 0000000..0f18486 --- /dev/null +++ b/protocol/fastpath/fastpath_test.go @@ -0,0 +1,66 @@ +package fastpath + +import ( + "net" + "rdp_channel/protocol/tpkt" + "testing" +) + +func TestFastPath(t *testing.T) { + go runServer(t) + runClient(t) +} + +func runServer(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:3388") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + for { + conn, err := listener.Accept() + if err != nil { + continue + } + + go func(conn net.Conn) { + defer conn.Close() + + tpkt := tpkt.New(conn) + fp := New(tpkt) + + dataLen, data, err := fp.Read() + if err != nil { + return + } + + t.Logf("fp server read data(%d bytes): %q\n", dataLen, data) + + _, err = fp.Write([]byte("fp server hello")) + }(conn) + } +} + +func runClient(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:3388") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + tpkt := tpkt.New(conn) + fp := New(tpkt) + + _, err = fp.Write([]byte("fp client hello")) + if err != nil { + t.Logf("fp client write error: %s\n", err) + } + + dataLen, data, err := fp.Read() + if err != nil { + t.Fatal(err) + } + + t.Logf("fp client read data(%d bytes): %q\n", dataLen, data) +}