增加SSH连接。

This commit is contained in:
dengqichen 2025-12-05 18:02:26 +08:00
parent 86276b2ffd
commit d4eb907536
2 changed files with 688 additions and 111 deletions

View File

@ -2,6 +2,7 @@ package com.qqchen.deploy.backend.deploy.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.core.task.AsyncTaskExecutor;
@ -12,102 +13,66 @@ import java.util.concurrent.ThreadPoolExecutor;
@EnableAsync
public class ThreadPoolConfig {
/**
* Jenkins任务同步线程池 - 使用虚拟线程Java 21+
*
* 为什么使用虚拟线程
* 1. Jenkins API调用是典型的**网络I/O密集型**任务
* 2. 等待Jenkins响应时线程会长时间阻塞
* 3. 虚拟线程在阻塞时不占用OS线程资源消耗极低
* 4. 支持数百个并发Jenkins构建同步
*/
@Bean("jenkinsTaskExecutor")
public ThreadPoolTaskExecutor jenkinsTaskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 核心线程数CPU核心数 + 1
executor.setCorePoolSize(Runtime.getRuntime().availableProcessors() + 1);
// 最大线程数CPU核心数 * 2
executor.setMaxPoolSize(Runtime.getRuntime().availableProcessors() * 2);
// 队列容量根据平均任务执行时间和期望响应时间来设置
executor.setQueueCapacity(50);
// 线程名前缀
executor.setThreadNamePrefix("jenkins-sync-");
// 线程空闲时间超过核心线程数的线程在空闲60秒后会被销毁
executor.setKeepAliveSeconds(60);
// 拒绝策略由调用线程处理
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// 等待所有任务完成再关闭线程池
executor.setWaitForTasksToCompleteOnShutdown(true);
// 等待时间
executor.setAwaitTerminationSeconds(60);
executor.initialize();
public SimpleAsyncTaskExecutor jenkinsTaskExecutor() {
SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor("jenkins-virtual-");
executor.setVirtualThreads(true);
executor.setConcurrencyLimit(-1); // 无限制
return executor;
}
/**
* 仓库项目同步线程池 - 使用虚拟线程Java 21+
*
* 为什么使用虚拟线程
* 1. Git操作clone/fetch/pull**I/O密集型**任务
* 2. 网络I/O从远程仓库拉取代码+ 磁盘I/O写入本地
* 3. 虚拟线程支持大量并发仓库同步无线程池限制
*/
@Bean("repositoryProjectExecutor")
public ThreadPoolTaskExecutor repositoryProjectExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 核心线程数CPU核心数 * 2
executor.setCorePoolSize(Runtime.getRuntime().availableProcessors() * 2);
// 最大线程数CPU核心数 * 4
executor.setMaxPoolSize(Runtime.getRuntime().availableProcessors() * 4);
// 队列容量根据平均任务执行时间和期望响应时间来设置
executor.setQueueCapacity(100);
// 线程名前缀
executor.setThreadNamePrefix("repository-project-sync-");
// 线程空闲时间超过核心线程数的线程在空闲60秒后会被销毁
executor.setKeepAliveSeconds(60);
// 拒绝策略由调用线程处理
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// 等待所有任务完成再关闭线程池
executor.setWaitForTasksToCompleteOnShutdown(true);
// 等待时间
executor.setAwaitTerminationSeconds(60);
executor.initialize();
public SimpleAsyncTaskExecutor repositoryProjectExecutor() {
SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor("repo-project-virtual-");
executor.setVirtualThreads(true);
executor.setConcurrencyLimit(-1); // 无限制
return executor;
}
/**
* 仓库分支同步线程池 - 使用虚拟线程Java 21+
*
* 为什么使用虚拟线程
* 1. Git分支操作checkout/merge/rebase**I/O密集型**任务
* 2. 大量磁盘I/O读取/写入Git对象
* 3. 虚拟线程支持数百个并发分支同步
*/
@Bean("repositoryBranchExecutor")
public ThreadPoolTaskExecutor repositoryBranchExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
// 核心线程数CPU核心数 * 2
executor.setCorePoolSize(Runtime.getRuntime().availableProcessors() * 2);
// 最大线程数CPU核心数 * 4
executor.setMaxPoolSize(Runtime.getRuntime().availableProcessors() * 4);
// 队列容量根据平均任务执行时间和期望响应时间来设置
executor.setQueueCapacity(100);
// 线程名前缀
executor.setThreadNamePrefix("repository-branch-sync-");
// 线程空闲时间超过核心线程数的线程在空闲60秒后会被销毁
executor.setKeepAliveSeconds(60);
// 拒绝策略由调用线程处理
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// 等待所有任务完成再关闭线程池
executor.setWaitForTasksToCompleteOnShutdown(true);
// 等待时间
executor.setAwaitTerminationSeconds(60);
executor.initialize();
public SimpleAsyncTaskExecutor repositoryBranchExecutor() {
SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor("repo-branch-virtual-");
executor.setVirtualThreads(true);
executor.setConcurrencyLimit(-1); // 无限制
return executor;
}
/**
* 通用应用任务线程池 - 保留平台线程不使用虚拟线程
*
* 为什么不使用虚拟线程
* 1. 通用线程池用途未知可能包含**CPU密集型**任务
* 2. CPU密集型任务使用虚拟线程反而会降低性能线程调度开销
* 3. 虚拟线程适合I/O密集型不适合计算密集型
* 4. 平台线程对CPU密集型任务更高效
*
* 💡 如果确认只用于I/O密集型任务可改为虚拟线程
*/
@Bean("applicationTaskExecutor")
public AsyncTaskExecutor applicationTaskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
@ -124,38 +89,40 @@ public class ThreadPoolConfig {
}
/**
* SSH输出监听线程池
* 用于异步读取SSH输出流并推送到WebSocket
* SSH输出监听线程池 - 使用虚拟线程Java 21+
*
* 为什么使用虚拟线程
* 1. SSH输出监听是典型的**阻塞I/O密集型**任务
* 2. 每个SSH连接需要2个长期阻塞的线程stdout + stderr
* 3. 虚拟线程几乎无资源开销支持数百万并发
* 4. 完美适配大量SSH长连接场景
*
* 📊 性能对比
* - 平台线程50个SSH连接 = 100个线程 100-200MB内存
* - 虚拟线程50个SSH连接 = 100个虚拟线程 几MB内存
*
* 💡 方案选择
* - 方案1当前SimpleAsyncTaskExecutor - Spring集成支持优雅关闭可定制线程名
* - 方案2Executors.newVirtualThreadPerTaskExecutor() - 原生API最简洁性能略优
*/
@Bean("sshOutputExecutor")
public ThreadPoolTaskExecutor sshOutputExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
public org.springframework.core.task.SimpleAsyncTaskExecutor sshOutputExecutor() {
// 方案1Spring封装的虚拟线程Executor推荐
// 优点与Spring集成支持优雅关闭线程名可定制便于调试
org.springframework.core.task.SimpleAsyncTaskExecutor executor =
new org.springframework.core.task.SimpleAsyncTaskExecutor("ssh-virtual-");
// 核心线程数预期同时活跃的SSH连接数
executor.setCorePoolSize(10);
// 关键启用虚拟线程Java 21+
executor.setVirtualThreads(true);
// 最大线程数支持的最大SSH连接数
executor.setMaxPoolSize(50);
// 并发限制-1表示无限制虚拟线程资源消耗极低
executor.setConcurrencyLimit(-1);
// 队列容量等待处理的SSH输出监听任务
executor.setQueueCapacity(100);
// 线程名前缀
executor.setThreadNamePrefix("ssh-output-");
// 线程空闲时间SSH会话关闭后线程60秒后回收
executor.setKeepAliveSeconds(60);
// 拒绝策略由调用线程处理确保SSH输出不会丢失
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
// 等待所有任务完成再关闭线程池
executor.setWaitForTasksToCompleteOnShutdown(true);
// 等待时间
executor.setAwaitTerminationSeconds(30);
executor.initialize();
return executor;
// 方案2原生虚拟线程Executor可选
// 如果需要纯Java实现无Spring依赖可以使用
// return Executors.newVirtualThreadPerTaskExecutor();
// 注意需要手动管理生命周期线程名为 VirtualThread-#1
}
}

View File

@ -0,0 +1,610 @@
package com.qqchen.deploy.backend.deploy.handler;
import com.qqchen.deploy.backend.deploy.dto.SSHMessage;
import com.qqchen.deploy.backend.deploy.entity.Server;
import com.qqchen.deploy.backend.deploy.enums.AuthTypeEnum;
import com.qqchen.deploy.backend.deploy.enums.SSHMessageTypeEnum;
import com.qqchen.deploy.backend.deploy.enums.SSHStatusEnum;
import com.qqchen.deploy.backend.deploy.service.ISSHAuditLogService;
import com.qqchen.deploy.backend.deploy.service.IServerService;
import com.qqchen.deploy.backend.framework.enums.ResponseCode;
import com.qqchen.deploy.backend.framework.exception.BusinessException;
import com.qqchen.deploy.backend.framework.utils.JsonUtils;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
import net.schmizz.sshj.userauth.password.PasswordUtils;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
/**
* Server SSH WebSocket处理器
* 处理Web SSH终端的WebSocket连接和SSH交互
*/
@Slf4j
@Component
public class ServerSSHWebSocketHandler extends TextWebSocketHandler {
@Resource
private IServerService serverService;
@Resource
private ISSHAuditLogService auditLogService;
@Resource(name = "sshOutputExecutor")
private AsyncTaskExecutor sshOutputExecutor;
/**
* 最大并发SSH会话数每个用户
*/
private static final int MAX_SESSIONS_PER_USER = 5;
/**
* WebSocket会话存储sessionId -> WebSocketSession
*/
private final Map<String, WebSocketSession> webSocketSessions = new ConcurrentHashMap<>();
/**
* SSH会话存储sessionId -> SSHClient
*/
private final Map<String, SSHClient> sshClients = new ConcurrentHashMap<>();
/**
* SSH会话通道存储sessionId -> Session.Shell
*/
private final Map<String, Session.Shell> sshShells = new ConcurrentHashMap<>();
/**
* 输出监听任务存储sessionId -> Future
*/
private final Map<String, Future<?>> outputTasks = new ConcurrentHashMap<>();
/**
* WebSocket连接建立时触发
*/
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
String sessionId = session.getId();
log.info("WebSocket连接建立: sessionId={}", sessionId);
try {
// 1. 从attributes中获取用户信息由认证拦截器设置
Long userId = (Long) session.getAttributes().get("userId");
String username = (String) session.getAttributes().get("username");
String clientIp = (String) session.getAttributes().get("clientIp");
String userAgent = (String) session.getAttributes().get("userAgent");
if (userId == null) {
log.error("无法获取用户信息: sessionId={}", sessionId);
sendError(session, "认证失败");
session.close(CloseStatus.POLICY_VIOLATION);
return;
}
// 2. 从URL中提取serverId
Long serverId = extractServerId(session);
if (serverId == null) {
sendError(session, "无效的服务器ID");
session.close(CloseStatus.BAD_DATA);
return;
}
// 3. 获取服务器信息
Server server = serverService.findEntityById(serverId);
if (server == null) {
sendError(session, "服务器不存在: " + serverId);
session.close(CloseStatus.NOT_ACCEPTABLE);
return;
}
// 4. 检查用户对该服务器的SSH会话数
long activeSessions = auditLogService.countUserActiveSessionsForServer(userId, serverId);
log.info("用户当前对该服务器的SSH连接数: userId={}, serverId={}, serverName={}, current={}, max={}",
userId, serverId, server.getServerName(), activeSessions, MAX_SESSIONS_PER_USER);
if (activeSessions >= MAX_SESSIONS_PER_USER) {
log.warn("用户对该服务器的SSH会话数超过限制: userId={}, serverId={}, serverName={}, current={}, max={}",
userId, serverId, server.getServerName(), activeSessions, MAX_SESSIONS_PER_USER);
sendError(session, "对服务器【" + server.getServerName() + "】的SSH连接数超过限制最多" + MAX_SESSIONS_PER_USER + "个)");
session.close(CloseStatus.POLICY_VIOLATION);
return;
}
// 5. 权限校验预留实际项目中需要实现
// TODO: 根据业务需求实现权限校验逻辑
// 例如检查用户是否是管理员或者服务器是否允许该用户访问
// 6. 发送连接中状态
sendStatus(session, SSHStatusEnum.CONNECTING);
// 7. 建立SSH连接
SSHClient sshClient = createSSHConnection(server);
sshClients.put(sessionId, sshClient);
// 8. 打开Shell通道并分配PTY伪终端
Session sshSession = sshClient.startSession();
// 关键分配PTY启用交互式Shell回显提示符
// 参数终端类型, 列数, 行数, 宽度(像素), 高度(像素), 终端模式
sshSession.allocatePTY("xterm", 80, 24, 0, 0, java.util.Collections.emptyMap());
log.debug("PTY已分配: sessionId={}, termType=xterm, cols=80, rows=24", sessionId);
Session.Shell shell = sshSession.startShell();
log.debug("Shell已启动: sessionId={}", sessionId);
// 保存会话信息
webSocketSessions.put(sessionId, session);
sshShells.put(sessionId, shell);
// 9. 优化先启动输出监听线程确保不错过任何SSH输出
Future<?> stdoutTask = sshOutputExecutor.submit(() -> readSSHOutput(session, shell));
outputTasks.put(sessionId, stdoutTask);
// 同时启动错误流监听某些SSH服务器会将输出发送到错误流
Future<?> stderrTask = sshOutputExecutor.submit(() -> readSSHError(session, shell));
outputTasks.put(sessionId + "_stderr", stderrTask);
// 10. 发送连接成功状态
sendStatus(session, SSHStatusEnum.CONNECTED);
log.info("SSH连接建立成功: sessionId={}, userId={}, username={}, server={}@{}",
sessionId, userId, username, server.getSshUser(), server.getHostIp());
// 11. 异步创建审计日志不阻塞主线程
// 使用CompletableFuture异步执行避免数据库操作延迟影响SSH输出接收
CompletableFuture.runAsync(() -> {
try {
Long auditLogId = auditLogService.createAuditLog(userId, server, sessionId, clientIp, userAgent);
session.getAttributes().put("auditLogId", auditLogId);
log.info("SSH审计日志已创建: auditLogId={}, sessionId={}", auditLogId, sessionId);
} catch (Exception e) {
log.error("创建SSH审计日志失败: sessionId={}", sessionId, e);
}
});
} catch (Exception e) {
log.error("建立SSH连接失败: sessionId={}", sessionId, e);
sendError(session, "连接失败: " + e.getMessage());
// 记录失败的审计日志
try {
// 异步场景直接尝试创建审计日志有锁保护已存在则直接返回
// 无需检查 attributes因为异步任务可能还未完成
Long userId = (Long) session.getAttributes().get("userId");
String clientIp = (String) session.getAttributes().get("clientIp");
String userAgent = (String) session.getAttributes().get("userAgent");
Long serverId = extractServerId(session);
if (userId != null && serverId != null) {
Server server = serverService.findEntityById(serverId);
if (server != null) {
// 先创建如果已存在则返回已有ID有锁保护不会重复
Long auditLogId = auditLogService.createAuditLog(userId, server, sessionId, clientIp, userAgent);
session.getAttributes().put("auditLogId", auditLogId);
// 再关闭
auditLogService.closeAuditLog(sessionId, "FAILED", e.getMessage());
}
}
} catch (Exception auditEx) {
log.error("记录失败审计日志异常", auditEx);
}
cleanupSession(sessionId);
session.close(CloseStatus.SERVER_ERROR);
}
}
/**
* 接收前端消息
*/
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
String sessionId = session.getId();
try {
// 解析消息
SSHMessage sshMessage = JsonUtils.fromJson(message.getPayload(), SSHMessage.class);
if (sshMessage == null) {
log.warn("解析消息失败: sessionId={}", sessionId);
return;
}
if (sshMessage.getType() != SSHMessageTypeEnum.INPUT) {
log.warn("收到非input类型消息: sessionId={}, type={}", sessionId, sshMessage.getType());
return;
}
// 获取SSH Shell
Session.Shell shell = sshShells.get(sessionId);
if (shell == null) {
sendError(session, "SSH连接未建立");
return;
}
// 发送命令到SSH
String input = sshMessage.getData();
if (input != null) {
OutputStream outputStream = shell.getOutputStream();
outputStream.write(input.getBytes(StandardCharsets.UTF_8));
outputStream.flush();
// 记录命令到审计日志只记录有意义的命令过滤掉单个字符的按键
if (input.length() > 0) {
auditLogService.recordCommand(sessionId, input);
}
log.debug("发送命令到SSH: sessionId={}, length={}", sessionId, input.length());
}
} catch (Exception e) {
log.error("处理WebSocket消息失败: sessionId={}", sessionId, e);
sendError(session, "命令执行失败: " + e.getMessage());
}
}
/**
* WebSocket连接关闭时触发
*/
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
String sessionId = session.getId();
log.info("WebSocket连接关闭: sessionId={}, status={}", sessionId, status);
// 关闭审计日志
try {
String auditStatus = status.getCode() == CloseStatus.NORMAL.getCode() ? "SUCCESS" : "INTERRUPTED";
auditLogService.closeAuditLog(sessionId, auditStatus, status.getReason());
} catch (Exception e) {
log.error("关闭审计日志失败: sessionId={}", sessionId, e);
}
// 清理资源
cleanupSession(sessionId);
}
/**
* WebSocket传输错误时触发
*/
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
String sessionId = session.getId();
log.error("WebSocket传输错误: sessionId={}", sessionId, exception);
// 记录错误到审计日志
try {
auditLogService.closeAuditLog(sessionId, "FAILED", "传输错误: " + exception.getMessage());
} catch (Exception e) {
log.error("关闭审计日志失败: sessionId={}", sessionId, e);
}
sendError(session, "传输错误: " + exception.getMessage());
cleanupSession(sessionId);
session.close(CloseStatus.SERVER_ERROR);
}
/**
* 创建SSH连接
*/
private SSHClient createSSHConnection(Server server) throws IOException {
SSHClient sshClient = new SSHClient();
// 跳过主机密钥验证生产环境建议使用正式的验证方式
sshClient.addHostKeyVerifier(new PromiscuousVerifier());
// 设置超时
sshClient.setTimeout(30000);
sshClient.setConnectTimeout(30000);
// 连接服务器
sshClient.connect(server.getHostIp(), server.getSshPort());
// 认证
if (server.getAuthType() == AuthTypeEnum.PASSWORD) {
// 密码认证
sshClient.authPassword(server.getSshUser(), server.getSshPassword());
} else if (server.getAuthType() == AuthTypeEnum.KEY) {
// 密钥认证
KeyProvider keyProvider;
if (server.getSshPassphrase() != null && !server.getSshPassphrase().isEmpty()) {
keyProvider = sshClient.loadKeys(server.getSshPrivateKey(), null,
PasswordUtils.createOneOff(server.getSshPassphrase().toCharArray()));
} else {
keyProvider = sshClient.loadKeys(server.getSshPrivateKey(), null, null);
}
sshClient.authPublickey(server.getSshUser(), keyProvider);
} else {
throw new BusinessException(ResponseCode.INVALID_PARAM, new Object[]{"不支持的认证类型: " + server.getAuthType()});
}
return sshClient;
}
/**
* 读取SSH输出并发送到前端
*/
private void readSSHOutput(WebSocketSession session, Session.Shell shell) {
String sessionId = session.getId();
log.debug("开始监听SSH输出: sessionId={}", sessionId);
try {
InputStream inputStream = shell.getInputStream();
byte[] buffer = new byte[1024];
int len;
log.debug("SSH输出流已获取开始循环读取: sessionId={}", sessionId);
while (session.isOpen() && (len = inputStream.read(buffer)) > 0) {
String output = new String(buffer, 0, len, StandardCharsets.UTF_8);
log.debug("读取到SSH输出: sessionId={}, length={}, content={}",
sessionId, len, output.replaceAll("\\r", "\\\\r").replaceAll("\\n", "\\\\n"));
sendOutput(session, output);
log.debug("SSH输出已发送到前端: sessionId={}", sessionId);
}
log.debug("SSH输出监听结束: sessionId={}, session.isOpen={}", sessionId, session.isOpen());
} catch (java.io.InterruptedIOException e) {
// 线程被中断正常的清理过程检查是否是WebSocket关闭导致的
if (!session.isOpen()) {
log.debug("SSH输出监听线程被正常中断WebSocket已关闭: sessionId={}", sessionId);
} else {
log.error("SSH输出监听线程被异常中断: sessionId={}", sessionId, e);
// 只在session仍然打开时尝试发送错误消息
try {
sendError(session, "SSH连接被中断");
} catch (Exception ex) {
log.debug("发送错误消息失败session可能已关闭: sessionId={}", sessionId);
}
}
} catch (IOException e) {
// 其他IO异常真正的错误
if (session.isOpen()) {
log.error("读取SSH输出失败: sessionId={}", sessionId, e);
try {
sendError(session, "读取SSH输出失败: " + e.getMessage());
} catch (Exception ex) {
log.debug("发送错误消息失败session可能已关闭: sessionId={}", sessionId);
}
} else {
log.debug("读取SSH输出时发生IO异常但session已关闭正常: sessionId={}", sessionId);
}
}
}
/**
* 读取SSH错误流并发送到前端
* 某些SSH服务器可能将输出发送到标准错误流
*/
private void readSSHError(WebSocketSession session, Session.Shell shell) {
String sessionId = session.getId();
log.debug("开始监听SSH错误流: sessionId={}", sessionId);
try {
InputStream errorStream = shell.getErrorStream();
byte[] buffer = new byte[1024];
int len;
log.debug("SSH错误流已获取开始循环读取: sessionId={}", sessionId);
while (session.isOpen() && (len = errorStream.read(buffer)) > 0) {
String output = new String(buffer, 0, len, StandardCharsets.UTF_8);
log.debug("读取到SSH错误流输出: sessionId={}, length={}, content={}",
sessionId, len, output.replaceAll("\\r", "\\\\r").replaceAll("\\n", "\\\\n"));
sendOutput(session, output); // 错误流也作为output发送到前端
log.debug("SSH错误流输出已发送到前端: sessionId={}", sessionId);
}
log.debug("SSH错误流监听结束: sessionId={}", sessionId);
} catch (java.io.InterruptedIOException e) {
if (!session.isOpen()) {
log.debug("SSH错误流监听线程被正常中断WebSocket已关闭: sessionId={}", sessionId);
} else {
log.error("SSH错误流监听线程被异常中断: sessionId={}", sessionId, e);
}
} catch (IOException e) {
if (session.isOpen()) {
log.error("读取SSH错误流失败: sessionId={}", sessionId, e);
} else {
log.debug("读取SSH错误流时发生IO异常但session已关闭正常: sessionId={}", sessionId);
}
}
}
/**
* 清理会话资源
*/
private void cleanupSession(String sessionId) {
log.debug("清理会话资源: sessionId={}", sessionId);
// 移除WebSocketSession
webSocketSessions.remove(sessionId);
// 取消输出监听任务标准输出
Future<?> stdoutTask = outputTasks.remove(sessionId);
if (stdoutTask != null && !stdoutTask.isDone()) {
stdoutTask.cancel(true);
}
// 取消错误流监听任务
Future<?> stderrTask = outputTasks.remove(sessionId + "_stderr");
if (stderrTask != null && !stderrTask.isDone()) {
stderrTask.cancel(true);
}
// 关闭SSH Shell
Session.Shell shell = sshShells.remove(sessionId);
if (shell != null) {
try {
shell.close();
} catch (IOException e) {
log.warn("关闭SSH Shell失败: sessionId={}", sessionId, e);
}
}
// 关闭SSH连接
SSHClient sshClient = sshClients.remove(sessionId);
if (sshClient != null) {
try {
sshClient.disconnect();
} catch (IOException e) {
log.warn("关闭SSH连接失败: sessionId={}", sessionId, e);
}
}
}
/**
* 从WebSocket session URL中提取serverId
*/
private Long extractServerId(WebSocketSession session) {
try {
String path = session.getUri().getPath();
// /api/v1/server-ssh/connect/{serverId}
String[] parts = path.split("/");
if (parts.length > 0) {
return Long.parseLong(parts[parts.length - 1]);
}
} catch (Exception e) {
log.error("提取serverId失败", e);
}
return null;
}
/**
* 发送output类型消息到前端
*/
private void sendOutput(WebSocketSession session, String output) throws IOException {
if (!session.isOpen()) {
return; // session已关闭直接返回
}
SSHMessage message = SSHMessage.output(output);
String json = JsonUtils.toJson(message);
if (json != null) {
session.sendMessage(new TextMessage(json));
}
}
/**
* 发送error类型消息到前端
*/
private void sendError(WebSocketSession session, String errorMessage) throws IOException {
if (!session.isOpen()) {
return; // session已关闭直接返回
}
SSHMessage message = SSHMessage.error(errorMessage);
String json = JsonUtils.toJson(message);
if (json != null) {
session.sendMessage(new TextMessage(json));
}
}
/**
* 发送status类型消息到前端
*/
private void sendStatus(WebSocketSession session, SSHStatusEnum status) throws IOException {
if (!session.isOpen()) {
return; // session已关闭直接返回
}
SSHMessage message = SSHMessage.status(status);
String json = JsonUtils.toJson(message);
if (json != null) {
session.sendMessage(new TextMessage(json));
}
}
/**
* 优雅下线应用关闭时清理所有活跃的SSH会话
* 使用 @PreDestroy 注解确保在Spring容器销毁前执行
*/
@jakarta.annotation.PreDestroy
public void gracefulShutdown() {
log.warn("====== 应用准备关闭开始优雅下线所有SSH会话 ======");
log.warn("当前活跃SSH会话数: {}", webSocketSessions.size());
if (webSocketSessions.isEmpty()) {
log.info("没有活跃的SSH会话跳过优雅下线");
return;
}
// 记录开始时间
long startTime = System.currentTimeMillis();
int successCount = 0;
int failureCount = 0;
// 遍历所有活跃会话
for (Map.Entry<String, WebSocketSession> entry : webSocketSessions.entrySet()) {
String sessionId = entry.getKey();
WebSocketSession session = entry.getValue();
try {
log.info("关闭SSH会话: sessionId={}", sessionId);
// 1. 尝试向前端发送服务器下线通知
try {
if (session.isOpen()) {
sendError(session, "服务器正在重启,连接即将关闭");
// 给前端一点时间接收消息
Thread.sleep(100);
}
} catch (Exception e) {
log.debug("发送下线通知失败: sessionId={}", sessionId, e);
}
// 2. 更新审计日志最重要防止僵尸会话
try {
auditLogService.closeAuditLog(sessionId, "SERVER_SHUTDOWN", "服务器优雅下线");
log.info("审计日志已更新: sessionId={}", sessionId);
} catch (Exception e) {
log.error("更新审计日志失败: sessionId={}", sessionId, e);
}
// 3. 清理资源
cleanupSession(sessionId);
// 4. 关闭WebSocket连接
try {
if (session.isOpen()) {
session.close(new CloseStatus(1001, "服务器正在重启"));
}
} catch (Exception e) {
log.debug("关闭WebSocket失败: sessionId={}", sessionId, e);
}
successCount++;
log.info("SSH会话关闭成功: sessionId={}", sessionId);
} catch (Exception e) {
failureCount++;
log.error("关闭SSH会话失败: sessionId={}", sessionId, e);
}
}
// 清空所有缓存
webSocketSessions.clear();
sshClients.clear();
sshShells.clear();
outputTasks.clear();
long duration = System.currentTimeMillis() - startTime;
log.warn("====== 优雅下线完成 ======");
log.warn("总会话数: {}, 成功: {}, 失败: {}, 耗时: {}ms",
successCount + failureCount, successCount, failureCount, duration);
}
}