diff --git a/protocol/core/transport/transport.go b/protocol/core/transport/transport.go new file mode 100644 index 0000000..114339f --- /dev/null +++ b/protocol/core/transport/transport.go @@ -0,0 +1,6 @@ +package transport + +type Transport interface { + Read() (int, []byte, error) + Write([]byte) (int, error) +} diff --git a/protocol/x224/x224.go b/protocol/x224/x224.go index bf8395e..c4f1acc 100644 --- a/protocol/x224/x224.go +++ b/protocol/x224/x224.go @@ -4,8 +4,9 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" - "rdp_channel/protocol/tpkt" + "io" + "log" + "rdp_channel/protocol/core/transport" ) /* 协议常量 */ @@ -44,17 +45,6 @@ type Negotiation struct { Payload uint32 } -// 格式化输出 -func (neg *Negotiation) String() string { - buff := &bytes.Buffer{} - buff.WriteString(fmt.Sprintf(" Type: 0x%02X\n", neg.Type)) - buff.WriteString(fmt.Sprintf(" Flags: 0x%02X\n", neg.Flags)) - buff.WriteString(fmt.Sprintf(" Length: 0x%04X (%d bytes)\n", neg.Length, neg.Length)) - buff.WriteString(fmt.Sprintf(" Payload: 0x%08X\n", neg.Payload)) - - return buff.String() -} - func (neg *Negotiation) parseNegotiation(reader *bytes.Reader) error { var err error @@ -83,63 +73,45 @@ func (neg *Negotiation) parseNegotiation(reader *bytes.Reader) error { // X224 协议封装 type X224 struct { - transport *tpkt.TPKT + transport transport.Transport reqProtocol uint32 selProtocol uint32 + + dataHandlers []func([]byte) + errorHandlers []func(error) } type X224PDU struct { - Len uint8 - Type byte - DstRef uint16 // 大端序 - SrcRef uint16 // 大端序 - ClsOpt uint8 - Cookie []byte - NegPayload *Negotiation + Len uint8 + Type byte + DstRef uint16 // 大端序 + SrcRef uint16 // 大端序 + ClsOpt uint8 + Cookie []byte + Payload []byte + NegMsg *Negotiation } -// 格式化输出 -func (pdu *X224PDU) String() string { - buff := &bytes.Buffer{} - buff.WriteString("X224PDU {\n") - buff.WriteString(fmt.Sprintf(" Len: 0x%02X (%d bytes)\n", pdu.Len, pdu.Len)) - buff.WriteString(fmt.Sprintf(" Type: 0x%02X\n", pdu.Type)) - buff.WriteString(fmt.Sprintf(" DstRef: 0x%04X (BigEndian)\n", pdu.DstRef)) - buff.WriteString(fmt.Sprintf(" SrcRef: 0x%04X (BigEndian)\n", pdu.SrcRef)) - buff.WriteString(fmt.Sprintf(" ClsOpt: 0x%02X\n", pdu.ClsOpt)) - buff.WriteString(" Cookie: ") - if len(pdu.Cookie) > 0 { - buff.WriteString(fmt.Sprintf("%q\n", string(pdu.Cookie))) - } else { - buff.WriteString("(empty)\n") - } - if pdu.NegPayload != nil { - buff.WriteString(" NegPayload: {\n") - buff.WriteString(pdu.NegPayload.String()) - buff.WriteString(" }\n") - } else { - buff.WriteString(" NegPayload: \n") - } - buff.WriteString("}\n") - - return buff.String() -} - -func New(transport *tpkt.TPKT) *X224 { +func New(transport transport.Transport) *X224 { return &X224{ transport: transport, reqProtocol: PROTOCOL_SSL, } } -func (x *X224) Read() (int, []byte, error) { - return x.transport.Read() +/* 注册事件回调函数 */ + +// OnData 数据回调 +func (x *X224) OnData(handler func([]byte)) { + x.dataHandlers = append(x.dataHandlers, handler) } -func (x *X224) Write(data []byte) (int, error) { - return x.transport.Write(data) +// OnError 错误回调 +func (x *X224) OnError(handler func(error)) { + x.errorHandlers = append(x.errorHandlers, handler) } +// 从字节流中解析PDU头部 func (x *X224) parsePduHeader(reader *bytes.Reader, pdu *X224PDU) error { var err error @@ -186,29 +158,76 @@ func (x *X224) serialize(pdu *X224PDU) []byte { _ = binary.Write(buff, binary.LittleEndian, pdu.ClsOpt) // 仅连接相关PDU包含Cookie和协商负载 - if pdu.Type == X224_CONNECTION_REQUEST || pdu.Type == X224_CONNECTION_CONFIRM { + switch pdu.Type { + case X224_CONNECTION_REQUEST, X224_CONNECTION_CONFIRM: if len(pdu.Cookie) > 0 { _ = binary.Write(buff, binary.LittleEndian, pdu.Cookie) _ = binary.Write(buff, binary.LittleEndian, []byte{0x0D, 0x0A}) } - if pdu.NegPayload != nil { - _ = binary.Write(buff, binary.LittleEndian, pdu.NegPayload.Type) - _ = binary.Write(buff, binary.LittleEndian, pdu.NegPayload.Flags) - _ = binary.Write(buff, binary.LittleEndian, pdu.NegPayload.Length) - _ = binary.Write(buff, binary.LittleEndian, pdu.NegPayload.Payload) + if pdu.NegMsg != nil { + _ = binary.Write(buff, binary.LittleEndian, pdu.NegMsg.Type) + _ = binary.Write(buff, binary.LittleEndian, pdu.NegMsg.Flags) + _ = binary.Write(buff, binary.LittleEndian, pdu.NegMsg.Length) + _ = binary.Write(buff, binary.LittleEndian, pdu.NegMsg.Payload) } + case X224_DATA: + buff.Write(pdu.Payload) } return buff.Bytes() } +// 处理数据消息 +func (x *X224) handleData(reader *bytes.Reader) { + buff, err := io.ReadAll(reader) + if err != nil { + x.handleError(err) + return + } + + for _, handler := range x.dataHandlers { + handler(buff) + } +} + +// 处理错误消息 +func (x *X224) handleError(err error) { + log.Println(err.Error()) + + for _, handler := range x.errorHandlers { + handler(err) + } +} + +// Write 发送数据消息 +func (x *X224) Write(data []byte) { + // 构造pdu + reqPdu := &X224PDU{ + Len: X224_HEADER_LENGTH + uint8(len(data)), // 头部长度 + 数据长度 + Type: X224_DATA, + DstRef: 0x00, + SrcRef: 0x00, + ClsOpt: 0x0, + Payload: data, + } + + // 序列化pdu + payload := x.serialize(reqPdu) + + // 写入传输层 + _, err := x.transport.Write(payload) + if err != nil { + x.handleError(err) + } +} + /* X224客户端相关实现 */ // ConnectToServer 客户端向服务端发起连接请求 -func (x *X224) ConnectToServer() error { +func (x *X224) ConnectToServer() { cookie := []byte("Cookie: mstshash=yv1ing") /* 构造pdu */ @@ -219,7 +238,7 @@ func (x *X224) ConnectToServer() error { SrcRef: 0x00, ClsOpt: 0x0, Cookie: cookie, - NegPayload: &Negotiation{ + NegMsg: &Negotiation{ Type: RDP_NEG_RSP, Flags: 0x00, Length: RDP_NEG_LENGTH, @@ -231,46 +250,57 @@ func (x *X224) ConnectToServer() error { payload := x.serialize(reqPdu) /* 写入传输层 */ - _, err := x.Write(payload) + _, err := x.transport.Write(payload) if err != nil { - return err + x.handleError(errors.New("[X224] failed to write pdu: " + err.Error())) + return } /* 等待处理服务端对连接请求的响应 */ - return x.handleConnectionConfirm() + go x.clientHandleServerMessage() +} + +// 客户端处理服务端的消息 +func (x *X224) clientHandleServerMessage() { + for { + li, packet, err := x.transport.Read() + if err != nil { + continue + } + + if li < 0x07 { + x.handleError(errors.New("[X224] invalid packet")) + return + } + + resPdu := &X224PDU{} + reader := bytes.NewReader(packet) + + err = x.parsePduHeader(reader, resPdu) + if err != nil { + x.handleError(errors.New("[X224] failed to parse pdu header: " + err.Error())) + return + } + + switch resPdu.Type { + case X224_CONNECTION_CONFIRM: + x.clientHandleConnectionConfirm(resPdu, reader) + case X224_DATA: + x.handleData(reader) + } + } } // handleConnectionConfirm 客户端处理服务端对连接请求的响应 -func (x *X224) handleConnectionConfirm() error { - li, packet, err := x.Read() - if err != nil { - return err - } - - if li < 0x07 { - return errors.New("[X224] invalid packet") - } - - resPdu := &X224PDU{} - reader := bytes.NewReader(packet) - - err = x.parsePduHeader(reader, resPdu) - if err != nil { - return errors.New("[X224] failed to parse pdu header: " + err.Error()) - } - +func (x *X224) clientHandleConnectionConfirm(resPdu *X224PDU, reader *bytes.Reader) { // 读取安全协议协商结果 neg := &Negotiation{} - err = neg.parseNegotiation(reader) + err := neg.parseNegotiation(reader) if err != nil { - return err + x.handleError(errors.New("[X224] failed to parse negotiation: " + err.Error())) } - resPdu.NegPayload = neg - - /* 完成安全协议协商 */ - fmt.Printf("client received server's confirm: \n%+v\n", resPdu.String()) - return nil + resPdu.NegMsg = neg } /* @@ -278,7 +308,7 @@ func (x *X224) handleConnectionConfirm() error { */ // 服务端向客户端发送响应 -func (x *X224) responseToClient(reqPdu *X224PDU) error { +func (x *X224) serverResponseToClient(reqPdu *X224PDU) { var err error // 构造协商响应 @@ -288,7 +318,7 @@ func (x *X224) responseToClient(reqPdu *X224PDU) error { DstRef: reqPdu.SrcRef, SrcRef: reqPdu.DstRef, ClsOpt: reqPdu.ClsOpt, - NegPayload: &Negotiation{ + NegMsg: &Negotiation{ Type: RDP_NEG_RSP, Flags: 0x00, Length: RDP_NEG_LENGTH, @@ -297,32 +327,48 @@ func (x *X224) responseToClient(reqPdu *X224PDU) error { } payload := x.serialize(resPdu) - _, err = x.Write(payload) + _, err = x.transport.Write(payload) if err != nil { - return errors.New("[X224] failed to write response: " + err.Error()) + x.handleError(errors.New("[X224] failed to write response: " + err.Error())) } +} - return nil +// 服务端处理客户端消息 +func (x *X224) serverHandleClientMessage() { + for { + _, packet, err := x.transport.Read() + if err != nil { + continue + } + + reqPdu := &X224PDU{} + reader := bytes.NewReader(packet) + + err = x.parsePduHeader(reader, reqPdu) + if err != nil { + x.handleError(errors.New("[X224] failed to parse pdu header: " + err.Error())) + return + } + + switch reqPdu.Type { + case X224_CONNECTION_REQUEST: + x.serverHandleConnectionRequest(reqPdu, reader) + case X224_DATA: + x.handleData(reader) + } + } } // 服务端处理客户端发来的连接请求 -func (x *X224) handleConnectionRequest(packet []byte) error { - var err error - - reqPdu := &X224PDU{} - reader := bytes.NewReader(packet) - - err = x.parsePduHeader(reader, reqPdu) - if err != nil { - return err - } +func (x *X224) serverHandleConnectionRequest(reqPdu *X224PDU, reader *bytes.Reader) { // 解析Cookie cookieBuff := make([]byte, 0, 32) for { b, err := reader.ReadByte() if err != nil { - return errors.New("[X224] failed to read cookie: " + err.Error()) + x.handleError(errors.New("[X224] failed to read cookie: " + err.Error())) + return } cookieBuff = append(cookieBuff, b) if len(cookieBuff) >= 2 && bytes.Equal(cookieBuff[len(cookieBuff)-2:], []byte{0x0D, 0x0A}) { @@ -334,16 +380,14 @@ func (x *X224) handleConnectionRequest(packet []byte) error { // 解析协商请求 reqNeg := &Negotiation{} if err := reqNeg.parseNegotiation(reader); err != nil { - return errors.New("[X224] failed to parse negotiation: " + err.Error()) + x.handleError(errors.New("[X224] failed to parse negotiation: " + err.Error())) + return } - reqPdu.NegPayload = reqNeg + reqPdu.NegMsg = reqNeg // 确定使用协议 x.selProtocol = PROTOCOL_SSL - /* 完成安全协议协商 */ - fmt.Printf("server received client's request: \n%+v\n", reqPdu.String()) - // 响应请求 - return x.responseToClient(reqPdu) + x.serverResponseToClient(reqPdu) } diff --git a/protocol/x224/x224_test.go b/protocol/x224/x224_test.go index 0296f7d..1ad13cb 100644 --- a/protocol/x224/x224_test.go +++ b/protocol/x224/x224_test.go @@ -1,9 +1,11 @@ package x224 import ( + "fmt" "net" "rdp_channel/protocol/tpkt" "testing" + "time" ) func TestX224(t *testing.T) { @@ -27,18 +29,15 @@ func runServer(t *testing.T) { go func(conn net.Conn) { defer conn.Close() - transport := tpkt.New(conn) - x224 := New(transport) + tpkt := tpkt.New(conn) + x224 := New(tpkt) - _, packet, err := x224.Read() - if err != nil { - t.Fatal(err) - } + x224.OnData(func(bytes []byte) { + fmt.Printf("server received: %s\n", string(bytes)) + x224.Write([]byte("yes! server hear!")) + }) - err = x224.handleConnectionRequest(packet) - if err != nil { - t.Fatal(err) - } + x224.serverHandleClientMessage() }(conn) } } @@ -50,11 +49,16 @@ func runClient(t *testing.T) { } defer conn.Close() - transport := tpkt.New(conn) - x224 := New(transport) + tpkt := tpkt.New(conn) + x224 := New(tpkt) - err = x224.ConnectToServer() - if err != nil { - t.Fatal(err) + x224.ConnectToServer() + x224.OnData(func(bytes []byte) { + fmt.Printf("client received: %s\n", string(bytes)) + }) + + for { + time.Sleep(1 * time.Second) + x224.Write([]byte("this is client!")) } }