郑永安
2023-06-19 7a6abd05683528032687c75e80e0bd2030a3e46c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
package com.gkhy.safePlatform.safeCheck.webSocket;
 
 
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
 
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
 
//@ServerEndpoint("/ws/test/{taskId}")
//@Component
public class SafeCheckWebSocketServer {
 
    private final Logger logger = LoggerFactory.getLogger(this.getClass());
 
    /**
     * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
     */
    private static int onlineCount = 0;
 
    /**
     * concurrent 包的线程安全Set,用来存放每个客户端对应的 myWebSocket对象
     * 根据taskId来获取对应的 WebSocket
     */
 
    private static ConcurrentHashMap<String,SafeCheckWebSocketServer> webSocketMap = new ConcurrentHashMap<>();
 
    /**
     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
     */
    private Session session;
 
    /**
     * 接收 taskId 任务id
     */
    private String taskId = "";
 
 
    /**
     * 连接建立成功调用的方法
     *
     * @param session
     * @param taskId
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("taskId") String taskId) {
        this.session = session;
        this.taskId = taskId;
 
        webSocketMap.put(taskId, this);
        logger.info("webSocketMap -> " + JSON.toJSONString(webSocketMap));
 
        addOnlineCount(); // 在线数 +1
        logger.info("有新窗口开始监听【任务id】:" + taskId + ",当前在线人数为" + getOnlineCount());
 
        try {
            sendMessage("连接成功");
        } catch (IOException e) {
            e.printStackTrace();
            throw new RuntimeException();
        }
 
    }
 
    /**
     * 关闭连接
     */
 
    @OnClose
    public void onClose() {
        if (webSocketMap.get(this.taskId) != null) {
            webSocketMap.remove(this.taskId);
            subOnlineCount(); // 人数 -1
            logger.info("有一任务连接关闭,当前在线人数为:" + getOnlineCount());
        }
    }
 
    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     * @param session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        logger.info("收到来自窗口" + taskId + "的信息:" + message);
 
        if (org.apache.commons.lang3.StringUtils.isNotBlank(message)) {
            try {
                // 解析发送的报文
                JSONObject jsonObject = JSON.parseObject(message);
                // 追加发送人(防窜改)
                jsonObject.put("fromTaskId", this.taskId);
                String toTaskId = jsonObject.getString("toTaskId");
                // 传送给对应 totaskId 用户的 WebSocket
                if (StringUtils.isNotBlank(toTaskId) && webSocketMap.containsKey(toTaskId)) {
                    webSocketMap.get(toTaskId).sendMessage(jsonObject.toJSONString());
                } else {
                    logger.info("请求的taskId:" + toTaskId + "不在该服务器上"); // 否则不在这个服务器上,发送到 MySQL 或者 Redis
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
 
    /**
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("用户错误:" + this.taskId + ",原因:" + error.getMessage());
        error.printStackTrace();
    }
 
    /**
     * 实现服务器主动推送
     *
     * @param message
     * @throws IOException
     */
    public void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
    }
 
    /**
     * 群发自定义消息
     *
     * @param message
     * @param taskId
     * @throws IOException
     */
    public static void sendInfo(String message, @PathParam("taskId") String taskId) throws IOException {
        // 遍历集合,可设置为推送给指定sid,为 null 时发送给所有人
        Iterator entrys = webSocketMap.entrySet().iterator();
        while (entrys.hasNext()) {
            Map.Entry entry = (Map.Entry) entrys.next();
 
            if (taskId == null) {
                webSocketMap.get(entry.getKey()).sendMessage(message);
            } else if (entry.getKey().toString().contains(taskId)) {
                webSocketMap.get(entry.getKey()).sendMessage(message);
            }
        }
    }
 
    /**
     * @description 判断taskid是否成功连接上服务器
     */
    public static boolean taskIsconnectByTaskId(String taskId){
 
        boolean taskIsconnect = false;
 
        Iterator entrys = webSocketMap.entrySet().iterator();
        while (entrys.hasNext()) {
            Map.Entry entry = (Map.Entry) entrys.next();
            if (entry.getKey().toString().contains(taskId)){
                return true;
            }
        }
        return taskIsconnect;
    }
 
    private static synchronized int getOnlineCount() {
        return onlineCount;
    }
 
    private static synchronized void addOnlineCount() {
        SafeCheckWebSocketServer.onlineCount++;
    }
 
    private static synchronized void subOnlineCount() {
        SafeCheckWebSocketServer.onlineCount--;
    }
}