import socket
import random
import json
from threading import Thread

from network import step


def handle_client(conn, addr):
    """Serve a single client connection: read a JSON request, send a JSON reply.

    The request body is read until the peer closes its sending side
    (``recv`` returns ``b''``), then decoded as UTF-8 JSON and passed to
    ``step(5, 3, ...)``. The result is converted with ``.tolist()``; its first
    element is JSON-encoded and sent back to the client.

    Any ``Exception`` raised along the way is printed rather than propagated.
    The connection is closed in every case.

    Args:
        conn: the accepted client socket.
        addr: the peer address; used only in log output.
    """
    with conn:  # guarantees conn.close() after the handler finishes or fails
        try:
            payload = bytearray()
            # Keep reading 1 KiB chunks until recv() yields b'' (EOF).
            for chunk in iter(lambda: conn.recv(1024), b''):
                payload.extend(chunk)

            request = json.loads(payload.decode('utf-8'))

            # Only the first element of step()'s list form is returned.
            first_item = step(5, 3, request).tolist()[0]

            response = json.dumps(first_item).encode('utf-8')
            conn.sendall(response)
            print(f"向 {addr} 发送: {response}")
        except Exception as e:
            print(f"客户端 {addr} 错误: {str(e)}")
def tensor_to_array(tensor):
    """Convert a torch-style tensor to a NumPy array.

    The tensor is first detached if it tracks gradients, then moved to the
    CPU if it lives on another device, because ``.numpy()`` requires both.

    Args:
        tensor: an object exposing ``requires_grad``, ``device``, ``detach()``,
            ``cpu()`` and ``numpy()`` (presumably a ``torch.Tensor``).

    Returns:
        The result of ``.numpy()`` on the detached, CPU-resident tensor.
    """
    plain = tensor.detach() if tensor.requires_grad else tensor
    on_cpu = plain if plain.device.type == 'cpu' else plain.cpu()
    return on_cpu.numpy()

def start_server(host='0.0.0.0', port=12345):
    """Start the TCP server and serve clients forever.

    Binds an IPv4 TCP socket to ``(host, port)``, listens, and hands each
    accepted connection to ``handle_client`` on a new thread. The accept loop
    never exits on its own; if it is interrupted by an exception (for example
    ``KeyboardInterrupt``), the ``with`` statement closes the listening socket.

    Args:
        host: local address to bind; ``'0.0.0.0'`` means all IPv4 interfaces.
        port: TCP port to listen on.
    """
    import os

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Let the server be restarted immediately while connections from a
        # previous run are still in TIME_WAIT; without this, bind() fails with
        # "Address already in use". Skipped on Windows, where SO_REUSEADDR has
        # different semantics (it lets another socket bind an in-use port).
        # This mirrors what socket.create_server() does.
        if os.name != 'nt':
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
        s.listen()
        print(f"服务端启动，监听 {host}:{port}")

        # Accept new connections forever; each client gets its own thread.
        while True:
            conn, addr = s.accept()
            print(f"收到来自 {addr} 的连接")
            Thread(target=handle_client, args=(conn, addr)).start()


if __name__ == "__main__":
    # Script entry point: serve on the default host and port until interrupted.
    start_server()