#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
预定系统全栈更新脚本 - 支持多环境配置版
功能：
1. 从本地共享路径自动发现最新版本
2. 先备份再更新服务器上的前后端服务
3. 支持后端服务在指定容器中执行启动脚本
4. 异常安全控制 + 日志记录
"""

import argparse
import fnmatch
import json
import os
import posixpath
import re
import shlex
import shutil
import sys
import tarfile
import tempfile
import time
from datetime import datetime

import paramiko
from tqdm import tqdm


def load_config(env="test"):
    """Return the configuration section for environment ``env``.

    The settings are read from ``Server_Config.json``, which must sit in the
    same directory as this script.

    Raises:
        FileNotFoundError: the JSON file does not exist.
        ValueError: the file has no top-level entry named ``env``.
    """
    script_dir = os.path.dirname(__file__)
    cfg_file = os.path.join(script_dir, "Server_Config.json")
    if not os.path.exists(cfg_file):
        raise FileNotFoundError(f"找不到配置文件: {cfg_file}")

    with open(cfg_file, "r", encoding="utf-8") as handle:
        all_envs = json.load(handle)

    if env in all_envs:
        return all_envs[env]
    raise ValueError(f"配置中没有找到环境 '{env}'")


class Deployer:
    """Deploy backend JARs and frontend bundles to one environment over SSH/SFTP.

    ``env_config`` is one environment section of ``Server_Config.json`` (as
    returned by ``load_config``). Keys read by this class: ``ssh``,
    ``version_discovery``, ``deploy_config`` and the optional
    ``ignorable_log_patterns``.

    Remote paths are always built with ``posixpath`` so that running the script
    from Windows does not produce backslash-separated paths on the server.
    """

    # Sentinel returned by the extract_*_version helpers when a name has no version.
    UNKNOWN_VERSION = (0, 0, 0, 0)
    # Retry policy for the initial SSH connection.
    CONNECT_RETRIES = 3
    CONNECT_RETRY_DELAY = 2  # seconds to wait between connection attempts

    def __init__(self, env_config):
        self.env_config = env_config
        self.ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts any unknown host key; consider loading
        # known_hosts and rejecting unknown keys for production targets.
        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.sftp = None
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Build the ``log(msg, level="INFO", important=False)`` callable used as ``self.logger``.

        WARNING/ERROR or ``important`` messages are printed in bold; ERROR
        messages are additionally written to stderr.
        """
        def log(msg, level="INFO", important=False):
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            log_msg = f"[{timestamp}] [{level}] {msg}"
            if level in ("WARNING", "ERROR") or important:
                print(f"\033[1m{log_msg}\033[0m")
                if level == "ERROR":
                    sys.stderr.write(log_msg + "\n")
            else:
                print(log_msg)
        return log

    # ------------------------------------------------------------------ connection

    def connect(self):
        """Open the SSH connection and an SFTP session, retrying on failure.

        Returns:
            True once connected, False after ``CONNECT_RETRIES`` failed attempts.
        """
        ssh_config = self.env_config['ssh']
        max_retries = self.CONNECT_RETRIES
        for attempt in range(1, max_retries + 1):
            try:
                self.logger(f"尝试连接服务器 ({attempt}/{max_retries})...", important=True)
                self.ssh.connect(
                    hostname=ssh_config['host'],
                    port=ssh_config['port'],
                    username=ssh_config['user'],
                    password=ssh_config['password'],
                    timeout=30,
                    banner_timeout=200
                )
                self.sftp = self.ssh.open_sftp()
                self.logger("服务器连接成功", important=True)
                return True
            except Exception as e:
                self.logger(f"连接尝试 {attempt} 失败: {str(e)}", "WARNING")
                # Drop any half-open transport (e.g. SSH ok but SFTP failed) before retrying.
                self.ssh.close()
                if attempt < max_retries:
                    time.sleep(self.CONNECT_RETRY_DELAY)
        self.logger("连接服务器最终失败", "ERROR")
        return False

    # ------------------------------------------------------------------ remote helpers

    def _ensure_remote_dir_exists(self, remote_dir):
        """Create ``remote_dir`` on the server, including missing parents (like ``mkdir -p``)."""
        try:
            self.sftp.stat(remote_dir)
            return
        except FileNotFoundError:
            pass
        parent = posixpath.dirname(remote_dir.rstrip('/'))
        if parent and parent != remote_dir:
            self._ensure_remote_dir_exists(parent)
        self.sftp.mkdir(remote_dir)
        self.logger(f"创建远程目录: {remote_dir}", "INFO")

    def _remote_file_exists(self, remote_path):
        """Return True if ``remote_path`` (file or directory) exists on the server."""
        try:
            self.sftp.stat(remote_path)
            return True
        except OSError:
            # paramiko reports missing paths / permission problems as IOError (an OSError).
            return False

    def _list_remote_files(self, remote_dir, pattern="*"):
        """List entry names in ``remote_dir`` matching the glob ``pattern``; [] on error."""
        try:
            files = self.sftp.listdir(remote_dir)
            return [f for f in files if fnmatch.fnmatch(f, pattern)]
        except Exception as e:
            self.logger(f"列出远程目录文件失败: {e}", "ERROR")
            return []

    def _read_remote_file(self, remote_path):
        """Return the UTF-8 text of ``remote_path``, or "" if it cannot be read."""
        try:
            with self.sftp.open(remote_path, 'r') as fp:
                return fp.read().decode('utf-8')
        except Exception as e:
            self.logger(f"读取远程文件失败: {e}", "ERROR")
            return ""

    def _write_remote_file(self, remote_path, content):
        """Write ``content`` (UTF-8) to ``remote_path``; return True on success."""
        try:
            with self.sftp.open(remote_path, 'w') as fp:
                fp.write(content.encode('utf-8'))
            return True
        except Exception as e:
            self.logger(f"写入远程文件失败: {e}", "ERROR")
            return False

    def _execute_command(self, command):
        """Run a shell ``command`` on the server; return True only if it exits with status 0.

        stdout is logged at INFO. On a non-zero exit, each stderr line is logged at
        ERROR unless it matches one of ``ignorable_log_patterns``; the command is
        still reported as failed.
        """
        patterns = self.env_config.get("ignorable_log_patterns") or []
        ignorable_patterns = [re.compile(p) for p in patterns]

        try:
            _, stdout, stderr = self.ssh.exec_command(command)
            # Drain the streams before waiting for the exit status: a command that
            # fills the channel window would otherwise block recv_exit_status().
            output = stdout.read().decode('utf-8', errors='replace').strip()
            error = stderr.read().decode('utf-8', errors='replace').strip()
            exit_code = stdout.channel.recv_exit_status()
        except Exception as e:
            self.logger(f"命令执行失败: {e}", "ERROR")
            return False

        if output:
            self.logger(f"命令输出: {output}", "INFO")
        if exit_code != 0:
            for line in error.splitlines():
                if any(p.search(line) for p in ignorable_patterns):
                    continue
                self.logger(f"命令错误: {line}", "ERROR")
            if not error:
                # A failing command with no stderr must still be reported as a failure.
                self.logger(f"命令退出码 {exit_code}: {command}", "ERROR")
            return False
        return True

    # ------------------------------------------------------------------ version helpers

    def _get_latest_local_folder(self, path):
        """Return the most recently modified sub-directory of ``path``, or None."""
        try:
            folders = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))]
            if not folders:
                raise Exception("未找到有效版本文件夹")
            latest = max(folders, key=lambda f: os.path.getmtime(os.path.join(path, f)))
            return os.path.join(path, latest)
        except Exception as e:
            self.logger(f"获取最新文件夹失败: {e}", "ERROR")
            return None

    def _get_latest_matching_file(self, folder, pattern):
        """Return the newest regular file in ``folder`` whose name matches ``pattern``, or None."""
        files = [
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if fnmatch.fnmatch(name, pattern) and os.path.isfile(os.path.join(folder, name))
        ]
        if not files:
            return None
        return max(files, key=os.path.getmtime)

    def extract_backend_version(self, folder_name):
        """Parse ``V<a>.<b>.<c>.<d>`` from ``folder_name``; ``UNKNOWN_VERSION`` if absent."""
        match = re.search(r'V(\d+\.\d+\.\d+\.\d+)', folder_name)
        if match:
            return tuple(map(int, match.group(1).split('.')))
        return (0, 0, 0, 0)

    def extract_frontend_version(self, folder_name):
        """Parse ``<a>.<b>.<c>.<d>`` from ``folder_name``; ``UNKNOWN_VERSION`` if absent."""
        match = re.search(r'(\d+\.\d+\.\d+\.\d+)', folder_name)
        if match:
            return tuple(map(int, match.group(1).split('.')))
        return (0, 0, 0, 0)

    @staticmethod
    def _parse_version_text(text):
        """Parse dotted version text such as ``"1.2.3.4\\n"`` into an int tuple; None if malformed."""
        try:
            return tuple(int(part) for part in text.strip().split('.'))
        except (AttributeError, ValueError):
            return None

    def _read_remote_version(self, version_path):
        """Read a remote version marker file; None when it is missing or unparseable."""
        if not version_path or not self._remote_file_exists(version_path):
            return None
        version = self._parse_version_text(self._read_remote_file(version_path))
        if version is None:
            self.logger(f"无法解析远程版本文件 {version_path}，将继续更新", "WARNING")
        return version

    def _should_skip(self, local_version, remote_version):
        """Return True when the deployed version is known and not older than ``local_version``.

        An unknown remote version (no marker) or an unknown local version (folder
        name without a version) never causes a skip.
        """
        if remote_version is None:
            return False
        if local_version == self.UNKNOWN_VERSION:
            self.logger("无法从文件夹名解析本地版本号，跳过版本比较", "WARNING")
            return False
        if local_version <= remote_version:
            self.logger(f"本地版本 {local_version} 不高于远程版本 {remote_version}，跳过更新", "WARNING")
            return True
        return False

    @staticmethod
    def _backend_version_path(deploy_config):
        """Remote path of the version marker kept next to the deployed JAR, or None if unconfigured."""
        remote_dir = deploy_config.get('remote_dir')
        jar_name = deploy_config.get('jar_name')
        if not remote_dir or not jar_name:
            return None
        return posixpath.join(remote_dir, f"{jar_name}.version")

    def _discover_backend_version(self, service_type):
        """Find the newest local JAR for ``service_type`` and decide whether to deploy it.

        The newest sub-folder (by mtime) of ``version_discovery.backend.<service_type>.search_path``
        is the candidate; its name supplies the version. The deployed version is read
        from the ``<jar_name>.version`` marker that ``_backup_and_update_backend`` writes
        next to the remote JAR; if it is missing the update proceeds.

        Returns:
            ``{'local_jar': path, 'local_version': tuple}``, or None when configuration is
            missing, no JAR is found, or the local version is not newer than the deployed one.
        """
        backend_config = self.env_config.get('version_discovery', {}).get('backend')
        if not backend_config:
            self.logger("版本发现配置中缺少 'backend'", "ERROR")
            return None

        config = backend_config.get(service_type)
        if not config:
            self.logger(f"找不到名为 {service_type} 的后端服务配置", "ERROR")
            return None

        latest_folder = self._get_latest_local_folder(config['search_path'])
        if not latest_folder:
            return None

        local_version = self.extract_backend_version(os.path.basename(latest_folder))

        # Compare against the version recorded on the server (the previous code looked
        # the JAR up under the LOCAL search path, so the check never ran).
        deploy_config = self.env_config.get('deploy_config', {}).get(service_type) or {}
        remote_version = self._read_remote_version(self._backend_version_path(deploy_config))
        if self._should_skip(local_version, remote_version):
            return None

        jar_file = self._get_latest_matching_file(latest_folder, config['file_pattern'])
        if not jar_file:
            self.logger(f"未找到匹配 {config['file_pattern']} 的 JAR 文件", "ERROR")
            return None

        return {'local_jar': jar_file, 'local_version': local_version}

    def _discover_frontend_version(self, frontend_type):
        """Find the newest local build for ``frontend_type`` and decide whether to deploy it.

        The deployed version is read from ``<remote_dir>/version.txt``; a missing or
        malformed file does not block the update.

        Returns:
            ``{'local_dir', 'files', 'local_version'}``, or None when configuration is
            missing, no local folder exists, or the local version is not newer.
        """
        frontend_config = self.env_config.get('version_discovery', {}).get('frontend')
        if not frontend_config:
            self.logger("版本发现配置中缺少 'frontend'", "ERROR")
            return None

        config = frontend_config.get(frontend_type)
        if not config:
            self.logger(f"找不到名为 {frontend_type} 的前端服务配置", "ERROR")
            return None

        latest_folder = self._get_latest_local_folder(config['search_path'])
        if not latest_folder:
            return None

        local_version = self.extract_frontend_version(os.path.basename(latest_folder))

        deploy_config = self.env_config.get('deploy_config', {}).get(frontend_type)
        if not deploy_config:
            self.logger(f"找不到前端 {frontend_type} 的更新配置", "ERROR")
            return None

        remote_version_path = posixpath.join(deploy_config['remote_dir'], 'version.txt')
        remote_version = self._read_remote_version(remote_version_path)
        if self._should_skip(local_version, remote_version):
            return None

        # NOTE(review): 'files' is returned but not used by the upload step, which ships
        # the whole folder.
        return {
            'local_dir': latest_folder,
            'files': config.get('files_to_update', []),
            'local_version': local_version,
        }

    # ------------------------------------------------------------------ backend

    def _upload_jar(self, local_jar, remote_dir, jar_name):
        """Upload ``local_jar`` over ``remote_dir/jar_name`` with a progress bar.

        The JAR is overwritten in place (same inode), which keeps single-file
        container bind mounts pointing at the new content.
        """
        try:
            remote_path = posixpath.join(remote_dir, jar_name)
            file_size = os.path.getsize(local_jar)
            self.logger(f"上传新版本至: {remote_path}")

            with tqdm(total=file_size, unit='B', unit_scale=True, desc="JAR上传进度") as progress_bar:
                def upload_progress_callback(sent, total):
                    progress_bar.update(sent - progress_bar.n)

                self.sftp.put(local_jar, remote_path, callback=upload_progress_callback)
            return True
        except Exception as e:
            self.logger(f"上传失败: {e}", "ERROR")
            return False

    def _restart_backend_service(self, container, exec_cmd):
        """Run ``exec_cmd`` inside ``container`` via ``docker exec``; return True on success.

        ``exec_cmd`` is passed through unquoted on purpose: it is a command line from config.
        """
        try:
            cmd = f"docker exec -i {shlex.quote(container)} {exec_cmd}"
            self.logger(f"正在重启服务: {cmd}")
            if not self._execute_command(cmd):
                raise Exception("服务重启失败")
            return True
        except Exception as e:
            self.logger(f"重启服务失败: {e}", "ERROR")
            return False

    def _backup_and_update_backend(self, service_type, version_info):
        """Back up the deployed JAR, upload the new one, restart the service, record the version.

        Args:
            service_type: key under ``deploy_config`` (e.g. ``'inner'``).
            version_info: dict produced by ``_discover_backend_version``.

        Returns:
            True on success, False on any failure (details are logged).
        """
        deploy_config = self.env_config.get('deploy_config', {})
        config = deploy_config.get(service_type)
        if not config:
            self.logger(f"找不到名为 {service_type} 的部署配置", "ERROR")
            return False

        remote_dir = config['remote_dir']
        backup_dir = config['backup_dir']
        jar_name = config['jar_name']
        local_version = version_info.get('local_version', self.UNKNOWN_VERSION)

        try:
            self.logger(f"\n===== 开始更新 {service_type} 后端服务 (版本: {local_version}) =====", important=True)

            self._ensure_remote_dir_exists(backup_dir)
            current_jar = posixpath.join(remote_dir, jar_name)

            # Step 1: back up the currently deployed JAR, if there is one.
            if self._remote_file_exists(current_jar):
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                backup_jar = posixpath.join(backup_dir, f"{jar_name}_{timestamp}.jar")
                self.logger(f"正在备份旧版本到: {backup_jar}")
                if not self._execute_command(f"cp {shlex.quote(current_jar)} {shlex.quote(backup_jar)}"):
                    raise Exception("备份失败")

            # Step 2: upload the new JAR.
            if not self._upload_jar(version_info['local_jar'], remote_dir, jar_name):
                raise Exception("上传失败")

            # Step 3: restart the service inside its container.
            if not self._restart_backend_service(config['container'], config['exec_cmd']):
                raise Exception("服务重启失败")

            # Step 4: record the deployed version so the next run can skip it.
            version_path = self._backend_version_path(config)
            if not self._write_remote_file(version_path, '.'.join(map(str, local_version))):
                self.logger(f"{service_type} 版本文件写入失败，下次运行将无法跳过相同版本", "WARNING")

            self.logger(f"{service_type} 后端更新成功 (版本: {local_version})", important=True)
            return True
        except Exception as e:
            self.logger(f"{service_type} 更新失败: {e}", "ERROR")
            return False

    # ------------------------------------------------------------------ frontend

    def _move_into_backup(self, source_path, backup_path):
        """Move ``source_path`` into the ``backup_path`` directory; raise if ``mv`` fails."""
        if not self._execute_command(f"mv {shlex.quote(source_path)} {shlex.quote(backup_path)}/"):
            raise Exception(f"移动 {source_path} 到备份目录失败")

    def _create_frontend_backup(self, remote_dir, backup_dir):
        """Move the live frontend (index.html, static/, *.worker.js) into a timestamped backup dir.

        The files are moved, not copied, so the new bundle is extracted into a clean
        directory. Returns True on success, False if any move fails.
        """
        try:
            self._ensure_remote_dir_exists(backup_dir)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_path = posixpath.join(backup_dir, f"backup_{timestamp}")

            self.sftp.mkdir(backup_path)
            self.logger(f"创建备份目录: {backup_path}")

            # 1. index.html
            index_html_path = posixpath.join(remote_dir, 'index.html')
            if self._remote_file_exists(index_html_path):
                self._move_into_backup(index_html_path, backup_path)
                self.logger(f"已备份 index.html 到 {backup_path}", "INFO")
            else:
                self.logger("未找到 index.html，跳过备份", "WARNING")

            # 2. static/ directory
            static_path = posixpath.join(remote_dir, 'static')
            if self._remote_file_exists(static_path):
                self._move_into_backup(static_path, backup_path)
                self.logger(f"已备份 static 目录到 {backup_path}", "INFO")
            else:
                self.logger("未找到 static 目录，跳过备份", "WARNING")

            # 3. *.worker.js files at the top level
            for worker_file in self._list_remote_files(remote_dir, pattern="*.worker.js"):
                self._move_into_backup(posixpath.join(remote_dir, worker_file), backup_path)
                self.logger(f"已备份 worker 文件: {worker_file}", "INFO")

            return True
        except Exception as e:
            # NOTE(review): a failure part-way leaves some files already moved out of remote_dir.
            self.logger(f"前端备份失败: {e}", "ERROR")
            return False

    @staticmethod
    def _build_frontend_archive(local_dir, archive_path):
        """Pack the contents of ``local_dir`` into a gzip tarball at ``archive_path``.

        Each top-level entry is added once, named relative to ``local_dir``.
        ``tarfile.add`` already recurses into directories, so walking the tree here
        would duplicate files. No ``.`` entry is stored, so extraction does not try
        to change the metadata of the target directory itself.
        """
        with tarfile.open(archive_path, "w:gz") as tar:
            for name in sorted(os.listdir(local_dir)):
                tar.add(os.path.join(local_dir, name), arcname=name)

    def _upload_frontend(self, local_dir, remote_dir):
        """Archive ``local_dir``, upload it to ``/tmp`` and extract it into ``remote_dir``.

        Returns True only if the upload and the remote extraction both succeed.
        The local and remote archives are removed in all cases.
        """
        fd, tmp_archive = tempfile.mkstemp(prefix="frontend_", suffix=".tar.gz")
        os.close(fd)
        # Unique per run, so concurrent deployments do not clobber each other.
        remote_archive = posixpath.join("/tmp", os.path.basename(tmp_archive))
        try:
            self._build_frontend_archive(local_dir, tmp_archive)
            file_size = os.path.getsize(tmp_archive)

            self.logger(f"正在上传文件到远程服务器: {remote_archive}")
            try:
                with tqdm(total=file_size, unit='B', unit_scale=True, desc="上传进度") as progress_bar:
                    def upload_progress_callback(sent, total):
                        progress_bar.update(sent - progress_bar.n)

                    self.sftp.put(tmp_archive, remote_archive, callback=upload_progress_callback)

                extracted = self._execute_command(
                    f"tar -xzf {shlex.quote(remote_archive)} -C {shlex.quote(remote_dir)}"
                )
            finally:
                self._execute_command(f"rm -f {shlex.quote(remote_archive)}")

            if not extracted:
                raise Exception("远程解压失败")
            return True
        except Exception as e:
            self.logger(f"前端上传失败: {e}", "ERROR")
            return False
        finally:
            try:
                os.unlink(tmp_archive)
            except OSError:
                pass

    def _update_frontend(self, frontend_type):
        """Back up, upload and version-stamp one frontend (``deploy_config.<frontend_type>``).

        When no newer local build is available, the update is skipped and treated as
        success, matching how backend services are handled in ``deploy()``.

        Returns:
            True on success or skip, False on failure.
        """
        deploy_config = self.env_config.get('deploy_config', {})
        config = deploy_config.get(frontend_type)
        if not config:
            self.logger(f"找不到名为 {frontend_type} 的前端部署配置", "ERROR")
            return False

        remote_dir = config['remote_dir']
        backup_dir = config['backup_dir']

        try:
            self.logger(f"\n===== 开始更新 {frontend_type} 前端 =====", important=True)

            version_info = self._discover_frontend_version(frontend_type)
            if not version_info:
                self.logger(f"{frontend_type} 前端更新跳过", "WARNING")
                return True

            local_dir = version_info['local_dir']
            local_version = version_info.get('local_version', self.UNKNOWN_VERSION)
            self.logger(f"使用版本路径: {local_dir}")

            # Step 1: move the current files into a backup directory.
            if not self._create_frontend_backup(remote_dir, backup_dir):
                raise Exception("备份失败")

            # Step 2: upload and extract the new build.
            if not self._upload_frontend(local_dir, remote_dir):
                raise Exception("上传失败")

            # Step 3: record the deployed version.
            version_file_path = posixpath.join(remote_dir, 'version.txt')
            if not self._write_remote_file(version_file_path, '.'.join(map(str, local_version))):
                self.logger(f"{frontend_type} 版本文件写入失败", "WARNING")

            self.logger(f"{frontend_type} 前端更新成功 (版本: {local_version})", important=True)
            return True
        except Exception as e:
            self.logger(f"{frontend_type} 前端更新失败: {e}", "ERROR")
            return False

    # ------------------------------------------------------------------ entry point

    def deploy(self):
        """Connect, update backends ('inner', 'external') then frontends ('front', 'admin').

        Backends with nothing newer to deploy are skipped. The first failing update
        aborts the run. The connection is always closed.

        Returns:
            True if every step succeeded or was skipped, else False.
        """
        try:
            if not self.connect():
                return False

            for service in ['inner', 'external']:
                info = self._discover_backend_version(service)
                if not info:
                    self.logger(f"{service} 服务更新跳过", "WARNING")
                    continue
                if not self._backup_and_update_backend(service, info):
                    return False

            for frontend in ['front', 'admin']:
                if not self._update_frontend(frontend):
                    return False

            self.logger("\n✅ 全栈更新完成！", important=True)
            return True
        except Exception as e:
            self.logger(f"主流程出错: {e}", "ERROR")
            return False
        finally:
            self._cleanup()

    def _cleanup(self):
        """Close the SFTP session and the SSH connection."""
        if self.sftp:
            self.sftp.close()
            self.sftp = None
        if self.ssh:
            self.ssh.close()
        self.logger("资源清理完成")


if __name__ == "__main__":
    # Command-line entry point: pick an environment, load its config and deploy.
    cli = argparse.ArgumentParser(description="预定系统更新工具")
    cli.add_argument(
        "--env",
        choices=["LZSH_Test", "Test_235"],
        default="LZSH_Test",
        help="选择更新环境 (LZSH_Test/Test_235)",
    )
    options = cli.parse_args()

    try:
        env_config = load_config(options.env)
    except Exception as exc:
        print(f"\033[1;31m❌ 配置加载失败: {exc}\033[0m")
        sys.exit(1)

    print(f"\n=== 预定系统服务更新工具 [{options.env.upper()} 环境] ===")

    # Warn when the configured SSH password is one of the known default values.
    # NOTE(review): these literals look like real credentials kept in source;
    # consider moving them out of the script.
    known_default_passwords = ("hzpassw0RD@KP", "prod_default_pass")
    if env_config['ssh']['password'] in known_default_passwords:
        print("\033[1;31m! 安全警告: 您正在使用默认密码!\033[0m")

    succeeded = Deployer(env_config).deploy()
    if not succeeded:
        print("\n\033[1;31m❌ 更新失败，请检查日志。\033[0m")
        sys.exit(1)
    print("\n\033[1;32m✅ 全栈更新成功!\033[0m")
    sys.exit(0)