import random
from typing import Dict

from meraki_Interface_forward.redis_utils import get_redis_client


def collect_prefix_ttl_stats(pattern: str, sample_size: int = 10, max_scan_keys: int = 200) -> Dict[str, object]:
    """
    收集指定前缀的 key 数量与 TTL 统计。
    - 精确 count：全量 scan 计数（仅保存前 max_scan_keys 个用于采样）
    - TTL 采样：最多 sample_size 个，使用 pipeline 批量查询
    """
    client = get_redis_client()
    if not client:
        return {"count": 0, "sampled": 0, "minTtl": None, "avgTtl": None}

    keys = []
    count = 0
    cursor = 0
    try:
        while True:
            cursor, batch = client.scan(cursor=cursor, match=pattern, count=500)
            batch = batch or []
            count += len(batch)
            # 仅保留前 max_scan_keys 个用于采样
            for k in batch:
                if len(keys) < max_scan_keys:
                    keys.append(k)
            if cursor == 0:
                break
    except Exception:
        return {"count": 0, "sampled": 0, "minTtl": None, "avgTtl": None}

    if count == 0 or not keys:
        return {"count": count, "sampled": 0, "minTtl": None, "avgTtl": None}

    # 随机采样（进一步减少 TTL 查询）
    actual_sample_size = min(len(keys), sample_size)
    sampled_keys = random.sample(keys, actual_sample_size)

    # 使用 pipeline 批量查询 TTL（一次性查询所有采样 key）
    ttls = []
    try:
        pipe = client.pipeline()
        for k in sampled_keys:
            pipe.ttl(k)
        results = pipe.execute()
        for t in results:
            if isinstance(t, int) and t >= 0:
                ttls.append(t)
    except Exception:
        # pipeline 失败时回退到单个查询（但这种情况应该很少）
        for k in sampled_keys[:5]:  # 只查询前 5 个，避免超时
            try:
                t = client.ttl(k)
                if isinstance(t, int) and t >= 0:
                    ttls.append(t)
            except Exception:
                continue

    return {
        "count": count,
        "sampled": len(sampled_keys),
        "minTtl": min(ttls) if ttls else None,
        "avgTtl": round(sum(ttls) / len(ttls), 2) if ttls else None,
    }


def build_prefix_stats(prefix_patterns: Dict[str, str], sample_size: int = 10) -> Dict[str, dict]:
    """
    生成按前缀的 TTL 统计结果。
    - 精确 count（全量 scan）
    - device 前缀：采样 5，保留 200 个键用于采样
    - 其他前缀：采样 5，保留 100 个键用于采样
    """
    result = {}
    for name, pattern in prefix_patterns.items():
        if name == "device":
            result[name] = collect_prefix_ttl_stats(pattern, sample_size=5, max_scan_keys=200)
        else:
            result[name] = collect_prefix_ttl_stats(pattern, sample_size=5, max_scan_keys=100)
    return result

