Java处理WebSocket协议

处理类（Handshake）：解析客户端的握手请求、生成握手响应，并读取以 0x00 开头、0xFF 结尾的消息帧。

/**
 * Server side of the legacy WebSocket opening handshake that is based on the
 * {@code Sec-WebSocket-Key1} / {@code Sec-WebSocket-Key2} request headers and
 * an 8-byte request body (the scheme of draft-hixie-thewebsocketprotocol-76),
 * plus a reader for frames delimited by a 0x00 start byte and a 0xFF end byte.
 *
 * <p>Instances read directly from the supplied stream and keep parsed request
 * data in {@link #http}; they are not designed for concurrent use.
 */
public class Handshake {
    private static final int CR = 13; // '\r'
    private static final int LF = 10; // '\n'

    /** Number of body bytes ("key 3") that follow the request header. */
    private static final int KEY3_LENGTH = 8;

    /** Each key quotient is written as an unsigned 32-bit big-endian integer. */
    private static final long MAX_KEY_NUMBER = 0xFFFFFFFFL;

    /** Single-byte charset used so that bytes and chars map one to one. */
    private static final String LATIN_1 = "ISO-8859-1";

    /** Header fields copied into {@link #http}, under exactly these keys. */
    private static final String[] HEADER_NAMES = {
        "Upgrade", "Connection", "Host", "Origin",
        "Sec-WebSocket-Key1", "Sec-WebSocket-Key2"
    };

    /** Header fields without which the response cannot be built. */
    private static final String[] REQUIRED_HEADERS = {
        "Host", "Origin", "Sec-WebSocket-Key1", "Sec-WebSocket-Key2"
    };

    private final InputStream input;

    /**
     * Parsed request data. Filled by {@link #getResponse()} with the keys
     * "Method", "Path", "Http-Protocol", "Content" and the header names
     * "Upgrade", "Connection", "Host", "Origin", "Sec-WebSocket-Key1",
     * "Sec-WebSocket-Key2" (the first two headers only if the client sent them).
     */
    Map<String, String> http = new HashMap<String, String>();

    /**
     * @param input
     *            stream the client request and subsequent frames are read from
     */
    public Handshake(InputStream input) {
        this.input = input;
    }

    /**
     * Reads the handshake request from the stream and builds the complete
     * handshake response: the "101" status line and headers followed by the
     * 16-byte MD5 challenge answer.
     *
     * @return the bytes to send back to the client
     * @throws IOException
     *             if reading fails, the stream ends before the header and the
     *             8 body bytes were received, the request line is malformed, a
     *             required header is missing, or a key is invalid
     */
    public byte[] getResponse() throws IOException {
        String header = readHeader();
        byte[] content = readFully(KEY3_LENGTH);

        parseHeader(header);
        http.put("Content", new String(content, LATIN_1));

        // number formed by the digits of each key, divided by its number of spaces
        long a = keyNumber("Sec-WebSocket-Key1");
        long b = keyNumber("Sec-WebSocket-Key2");

        // challenge input: a and b as 32-bit big-endian integers, then the 8 body bytes
        byte[] seed = new byte[4 + 4 + KEY3_LENGTH];
        putInt(seed, 0, a);
        putInt(seed, 4, b);
        System.arraycopy(content, 0, seed, 8, KEY3_LENGTH);

        byte[] challenge;
        try {
            challenge = MessageDigest.getInstance("MD5").digest(seed);
        } catch (NoSuchAlgorithmException e) {
            // Fail loudly instead of continuing with a null challenge.
            throw new IllegalStateException("MD5 digest is not available", e);
        }

        StringBuilder sb = new StringBuilder();
        sb.append("HTTP/1.1 101 WebSocket Protocol Handshake\r\n");
        sb.append("Upgrade: WebSocket\r\n");
        sb.append("Connection: Upgrade\r\n");
        sb.append("Sec-WebSocket-Origin: " + http.get("Origin") + "\r\n");
        sb.append("Sec-WebSocket-Location: ws://" + http.get("Host")
                + http.get("Path") + "\r\n\r\n");

        byte[] head = sb.toString().getBytes(LATIN_1);
        byte[] response = new byte[head.length + challenge.length];
        System.arraycopy(head, 0, response, 0, head.length);
        System.arraycopy(challenge, 0, response, head.length, challenge.length);
        return response;
    }

    /**
     * Reads the next frame from the stream. Bytes are discarded until a 0x00
     * start byte is seen; a further 0x00 restarts the frame; reading stops
     * after a 0xFF end byte or at end of stream.
     *
     * @return the frame including its 0x00 start byte and, if it was received,
     *         its 0xFF end byte; {@code null} if the stream ended before any
     *         start byte was read
     * @throws IOException
     *             if reading from the stream fails
     */
    public byte[] getMsg() throws IOException {
        byte[] frame = null;
        int length = 0;
        int b;
        while ((b = input.read()) != -1) {
            if (b == 0x00) {
                // frame start: drop whatever was collected so far
                if (frame == null) {
                    frame = new byte[64];
                }
                length = 0;
            }
            if (frame == null) {
                continue; // not inside a frame yet
            }
            if (length == frame.length) {
                byte[] bigger = new byte[frame.length * 2];
                System.arraycopy(frame, 0, bigger, 0, length);
                frame = bigger;
            }
            frame[length++] = (byte) b;
            if (b == 0xFF) {
                break;
            }
        }
        if (frame == null) {
            return null;
        }
        byte[] result = new byte[length];
        System.arraycopy(frame, 0, result, 0, length);
        return result;
    }

    /** Reads up to and including the run of four CR/LF bytes that ends the header. */
    private String readHeader() throws IOException {
        StringBuilder header = new StringBuilder();
        int lineBreakRun = 0; // consecutive CR/LF bytes seen; 4 marks the end of the header
        int b;
        while ((b = input.read()) != -1) {
            header.append((char) b); // b is 0..255, so this is a one-to-one mapping
            if (b == CR || b == LF) {
                if (++lineBreakRun == 4) {
                    return header.toString();
                }
            } else {
                lineBreakRun = 0;
            }
        }
        throw new IOException("Stream ended before the end of the handshake header");
    }

    /** Reads exactly {@code count} bytes, looping over short reads. */
    private byte[] readFully(int count) throws IOException {
        byte[] buffer = new byte[count];
        int offset = 0;
        while (offset < count) {
            int read = input.read(buffer, offset, count - offset);
            if (read == -1) {
                throw new IOException("Stream ended after " + offset + " of "
                        + count + " handshake body bytes");
            }
            offset += read;
        }
        return buffer;
    }

    /** Parses the request line and the known header fields into {@link #http}. */
    private void parseHeader(String header) throws IOException {
        String[] lines = header.split("\r\n");
        String[] requestLine = (lines.length == 0) ? new String[0] : lines[0].split(" ");
        if (requestLine.length < 3) {
            throw new IOException("Malformed handshake request line");
        }

        Map<String, String> parsed = new HashMap<String, String>();
        parsed.put("Method", requestLine[0]);
        parsed.put("Path", requestLine[1]);
        parsed.put("Http-Protocol", requestLine[2]);

        for (int i = 1; i < lines.length; i++) {
            int colon = lines[i].indexOf(':');
            if (colon < 0) {
                continue;
            }
            String name = lines[i].substring(0, colon).trim();
            String value = lines[i].substring(colon + 1);
            // Strip only the single separator space: spaces inside the keys are significant.
            if (value.startsWith(" ")) {
                value = value.substring(1);
            }
            for (String known : HEADER_NAMES) {
                // Match whole field names, ignoring case; the first occurrence wins.
                if (known.equalsIgnoreCase(name) && !parsed.containsKey(known)) {
                    parsed.put(known, value);
                }
            }
        }

        for (String required : REQUIRED_HEADERS) {
            if (!parsed.containsKey(required)) {
                throw new IOException("Missing handshake header: " + required);
            }
        }
        http.putAll(parsed);
    }

    /** Computes (digits of the key read as a number) / (number of spaces in the key). */
    private long keyNumber(String headerName) throws IOException {
        String key = http.get(headerName);
        long digits = 0;
        int digitCount = 0;
        int spaces = 0;
        for (int i = 0; i < key.length(); i++) {
            char c = key.charAt(i);
            if (c >= '0' && c <= '9') {
                if (digits > (Long.MAX_VALUE - 9) / 10) {
                    throw new IOException(headerName + " contains too many digits");
                }
                digits = digits * 10 + (c - '0');
                digitCount++;
            } else if (c == ' ') {
                spaces++;
            }
        }
        if (digitCount == 0 || spaces == 0) {
            throw new IOException(headerName
                    + " must contain at least one digit and one space");
        }
        long number = digits / spaces;
        if (number > MAX_KEY_NUMBER) {
            throw new IOException(headerName + " does not fit into 32 bits");
        }
        return number;
    }

    /** Writes the low 32 bits of {@code value} big-endian at {@code offset}. */
    private static void putInt(byte[] target, int offset, long value) {
        target[offset] = (byte) (value >>> 24);
        target[offset + 1] = (byte) (value >>> 16);
        target[offset + 2] = (byte) (value >>> 8);
        target[offset + 3] = (byte) value;
    }
}

工具类（Utils）：提供字节数组拼接、字符过滤以及十六进制字符串与字节数组互相转换等辅助方法。

/**
 * Static helpers for joining byte arrays, filtering characters out of strings
 * and converting between byte arrays and hexadecimal strings.
 */
public class Utils {

    private static final char[] HEX_DIGITS = "0123456789ABCDEF".toCharArray();

    /**
     * Concatenates two byte arrays.
     *
     * @param target
     *            leading bytes; {@code null} is treated as an empty array
     * @param src
     *            bytes to append; {@code null} is treated as an empty array
     * @return a new array holding the bytes of {@code target} followed by the
     *         bytes of {@code src}
     */
    public static byte[] addByte(byte[] target, byte[] src) {
        int targetLength = (target == null) ? 0 : target.length;
        int srcLength = (src == null) ? 0 : src.length;
        byte[] joined = new byte[targetLength + srcLength];
        if (targetLength > 0) {
            System.arraycopy(target, 0, joined, 0, targetLength);
        }
        if (srcLength > 0) {
            System.arraycopy(src, 0, joined, targetLength, srcLength);
        }
        return joined;
    }

    /**
     * Removes every character that is not a digit.
     * Example: {@code "uis sdj13 e8 kj*ks90ao"} yields {@code "13890"}.
     *
     * @param str
     *            text to filter
     * @return the digits of {@code str} in their original order; if
     *         {@code str} is {@code null} or empty it is returned unchanged
     */
    public static String filterNonNumeric(String str) {
        // Compare by content: "==" on strings only tests reference identity.
        if (str == null || str.length() == 0) {
            return str;
        }
        StringBuilder digits = new StringBuilder(str.length());
        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);
            if (Character.isDigit(c)) {
                digits.append(c);
            }
        }
        return digits.toString();
    }

    /**
     * Removes every character that is not a space (U+0020).
     * Example: {@code "uis sdj13 e8 kj*ks90ao"} yields a string of three spaces.
     *
     * @param str
     *            text to filter
     * @return a string made of the spaces found in {@code str}; if
     *         {@code str} is {@code null} or empty it is returned unchanged
     */
    public static String filterNonSpace(String str) {
        if (str == null || str.length() == 0) {
            return str;
        }
        StringBuilder spaces = new StringBuilder();
        for (int i = 0; i < str.length(); i++) {
            if (str.charAt(i) == ' ') {
                spaces.append(' ');
            }
        }
        return spaces.toString();
    }

    /**
     * Converts a byte array to an upper-case hexadecimal string, two digits
     * per byte.
     * Example: {@code new byte[]{0, (byte) 255}} yields {@code "00FF"}.
     *
     * @param b
     *            bytes to convert
     * @return upper-case hexadecimal representation of {@code b}
     */
    public static String bytes2HexStr(byte[] b) {
        StringBuilder hex = new StringBuilder(b.length * 2);
        for (byte value : b) {
            // Table lookup keeps the output independent of the default locale.
            hex.append(HEX_DIGITS[(value >> 4) & 0x0F]);
            hex.append(HEX_DIGITS[value & 0x0F]);
        }
        return hex.toString();
    }

    /**
     * Converts a hexadecimal string to a byte array, two digits per byte.
     * Example: {@code "00FF"} yields {@code new byte[]{0, (byte) 255}}.
     *
     * @param src
     *            hexadecimal string; if its length is odd, the trailing
     *            unpaired digit is ignored
     * @return the decoded bytes
     * @throws NumberFormatException
     *             if a digit pair cannot be parsed as a base-16 number
     */
    public static byte[] hexStr2Bytes(String src) {
        int byteCount = src.length() / 2;
        byte[] bytes = new byte[byteCount];
        for (int i = 0; i < byteCount; i++) {
            int start = i * 2;
            bytes[i] = (byte) Integer.parseInt(src.substring(start, start + 2), 16);
        }
        return bytes;
    }
}

至此，基于 Sec-WebSocket-Key1 / Sec-WebSocket-Key2 的 WebSocket 握手解析与应答，以及 0x00…0xFF 消息帧的读取就完成了，协议分析的工作到此结束。