mirror of
https://github.com/yv1ing/rdp_channel.git
synced 2025-09-16 14:59:08 +08:00
基本完成TPKT协议
This commit is contained in:
86
protocol/tpkt/tpkt.go
Normal file
86
protocol/tpkt/tpkt.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package tpkt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// 协议常量
|
||||
const (
|
||||
TPKT_VERSION = 0x03
|
||||
TPKT_RESERVED = 0x00
|
||||
TPKT_HEADER_LENGTH = 0x04
|
||||
TPKT_MAX_PACKET_LENGTH = 0xffff
|
||||
)
|
||||
|
||||
var (
|
||||
TPKT_INVALID_VERSION = errors.New("[TPKT] invalid version")
|
||||
TPKT_INVALID_PACKET_LENGTH = errors.New("[TPKT] invalid packet length")
|
||||
)
|
||||
|
||||
// TPKT 协议封装
|
||||
type TPKT struct {
|
||||
connection net.Conn
|
||||
readBuff *bufio.Reader
|
||||
}
|
||||
|
||||
func New(conn net.Conn) *TPKT {
|
||||
tpkt := &TPKT{
|
||||
connection: conn,
|
||||
readBuff: bufio.NewReader(conn),
|
||||
}
|
||||
|
||||
return tpkt
|
||||
}
|
||||
|
||||
// Write 发送TPKT包
|
||||
func (tpkt *TPKT) Write(data []byte) (int, error) {
|
||||
dataLen := len(data)
|
||||
if dataLen > (TPKT_MAX_PACKET_LENGTH - TPKT_HEADER_LENGTH) {
|
||||
return 0, TPKT_INVALID_PACKET_LENGTH
|
||||
}
|
||||
|
||||
// TPKT封包
|
||||
packet := make([]byte, TPKT_HEADER_LENGTH+dataLen)
|
||||
packet[0] = TPKT_VERSION
|
||||
packet[1] = TPKT_RESERVED
|
||||
binary.BigEndian.PutUint16(packet[2:4], uint16(TPKT_HEADER_LENGTH+dataLen)) // 第2、3字节为tpkt包的长度
|
||||
|
||||
// 装填载荷
|
||||
copy(packet[TPKT_HEADER_LENGTH:], data)
|
||||
|
||||
// 发送TPKT包
|
||||
return tpkt.connection.Write(packet)
|
||||
}
|
||||
|
||||
func (tpkt *TPKT) Read() (int, []byte, error) {
|
||||
// 验证TPKT头
|
||||
header := make([]byte, TPKT_HEADER_LENGTH)
|
||||
_, err := io.ReadFull(tpkt.readBuff, header)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
// 验证TPKT版本
|
||||
if header[0] != TPKT_VERSION {
|
||||
return 0, nil, TPKT_INVALID_VERSION
|
||||
}
|
||||
|
||||
// 验证载荷长度
|
||||
length := binary.BigEndian.Uint16(header[2:4])
|
||||
if length > TPKT_MAX_PACKET_LENGTH {
|
||||
return 0, nil, TPKT_INVALID_PACKET_LENGTH
|
||||
}
|
||||
|
||||
// 读取载荷数据
|
||||
dataLen := length - TPKT_HEADER_LENGTH
|
||||
data := make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(tpkt.readBuff, data); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
return int(dataLen), data, nil
|
||||
}
|
||||
68
protocol/tpkt/tpkt_test.go
Normal file
68
protocol/tpkt/tpkt_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package tpkt
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTpkt(t *testing.T) {
|
||||
go runServer(t)
|
||||
runClient(t)
|
||||
}
|
||||
|
||||
func runServer(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:3388")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
t.Logf("tpkt server listening at %s\n", listener.Addr())
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
t.Logf("tpkt server accept error: %s\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf("tpkt server accepted connection from %s\n", conn.RemoteAddr())
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
tpkt := New(conn)
|
||||
dataLen, data, err := tpkt.Read()
|
||||
if err != nil {
|
||||
t.Logf("tpkt server read error: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("tpkt server read data(%d): %q\n", dataLen, data)
|
||||
|
||||
_, err = tpkt.Write([]byte("tpkt server hello"))
|
||||
if err != nil {
|
||||
t.Logf("tpkt server write error: %s\n", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func runClient(t *testing.T) {
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:3388")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
tpkt := New(conn)
|
||||
|
||||
_, err = tpkt.Write([]byte("tpkt client hello"))
|
||||
if err != nil {
|
||||
t.Logf("tpkt client write error: %s\n", err)
|
||||
}
|
||||
|
||||
dataLen, data, err := tpkt.Read()
|
||||
if err != nil {
|
||||
t.Logf("tpkt client read error: %s\n", err)
|
||||
}
|
||||
t.Logf("tpkt client read data(%d): %q\n", dataLen, data)
|
||||
}
|
||||
Reference in New Issue
Block a user