package internal

import (
	"agent/commons/bytes"
	"agent/commons/debug"
	"agent/defines/derrs"
	"agent/internal/include"
	"agent/modules/ports/api"
	"agent/modules/ports/defines/merrs"
	"context"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
)

func NewTunnelPortOpen(cl include.IClient, params api.PortParams) (*tunPortOpen, error) {
	if params.Host == "" {
		params.Host = "127.0.0.1"
	}
	if params.Port <= 0 {
		return nil, derrs.NewIncorrectParamsError("port <= 0")
	}
	if params.LinkTo.Host == "" {
		return nil, derrs.NewIncorrectParamsError("linkTo.Host is empty")
	}
	if params.LinkTo.Port == 0 {
		return nil, derrs.NewIncorrectParamsError("linkTo.Port == 0")
	}
	if len(params.Handshake.Key) == 0 {
		return nil, derrs.NewIncorrectParamsError("handshake.key is empty")
	}
	if len(params.Handshake.Value) == 0 {
		return nil, derrs.NewIncorrectParamsError("handshake.value is empty")
	}
	return &tunPortOpen{
		params: params,
	}, nil
}

type tunPortOpen struct {
	params    api.PortParams
	lock      sync.Mutex
	runOnce   atomic.Bool
	srvConn   net.Conn
	agentConn net.Conn
}

func (c *tunPortOpen) Params() api.PortParams {
	return c.params
}

func (c *tunPortOpen) Close() error {
	go func() {
		c.lock.Lock()
		defer c.lock.Unlock()
		if c.params.Uuid != uuid.Nil {
			Ports.Delete(c.params.Uuid)
		}
		if c.srvConn != nil {
			c.srvConn.Close()
		}
		if c.agentConn != nil {
			c.agentConn.Close()
		}
	}()
	return nil
}

func (c *tunPortOpen) Run(ctx context.Context) (err error) {
	if c.runOnce.CompareAndSwap(false, true) {
		if _, ok := Ports.Load(c.params.Uuid); ok {
			// Порт уже запущен
			return nil
		}

		Ports.Store(c.params.Uuid, c)
		defer func() {
			if c.params.Uuid != uuid.Nil {
				Ports.Delete(c.params.Uuid)
			}
		}()

		for {
			if _, ok := Ports.Load(c.params.Uuid); !ok {
				// Порт был закрыт
				return nil
			}
			repeat, err := c.createTunnel(ctx)
			if err != nil {
				debug.Log(err)
			}
			if !repeat {
				debug.Logf(`Tunnel closed for port "%d" on host "%s"`, c.params.Port, c.params.Host)
				return err
			}
		}
	}
	return nil
}

func (c *tunPortOpen) createTunnel(ctx context.Context) (repeat bool, err error) {
	srvConn, err := c.connectToSrv(c.params.LinkTo.Host, c.params.LinkTo.Port)
	if err != nil {
		return false, err
	}
	defer srvConn.Close()

	sharedKey := c.params.Handshake.Key
	signature := c.params.Handshake.Value
	xorKey := bytes.Random(len(sharedKey))

	if err = c.writeHandshake(srvConn, xorKey, sharedKey, signature); err != nil {
		return false, err
	}

	return c.link(srvConn, xorKey, c.params)
}

func (c *tunPortOpen) writeHandshake(srvConn net.Conn, xorKey []byte, sharedKey []byte, signature []byte) error {
	// Если handshake не будет отправлен за указанное время, то открытие порта завершится с ошибкой
	if err := srvConn.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil {
		return merrs.Join(err, merrs.NewFailedSetWriteDeadlineError(err.Error()))
	}

	obfSignature := make([]byte, len(signature))
	copy(obfSignature, signature)
	for i := range obfSignature {
		obfSignature[i] ^= xorKey[i%len(xorKey)]
	}

	obfXorKey := make([]byte, len(xorKey))
	copy(obfXorKey, xorKey)
	for i := range obfXorKey {
		obfXorKey[i] ^= sharedKey[i%len(sharedKey)]
	}

	handshake := make([]byte, len(obfXorKey)+len(obfSignature))
	copy(handshake, obfXorKey)
	copy(handshake[len(obfXorKey):], obfSignature)

	if _, err := srvConn.Write(handshake); err != nil {
		return merrs.Join(err, merrs.NewFailedSendHandshakeError(err.Error()))
	}

	// Сбрасываем таймаут, чтобы не мешал дальнейшей работе
	srvConn.SetWriteDeadline(time.Time{})

	return nil
}

func (c *tunPortOpen) link(srvConn net.Conn, xorKey []byte, params api.PortParams) (repeat bool, err error) {
	var buff = make([]byte, 256)
	// Читаем размер сигнального пакета, который сервер отправляет, когда он готов обмену данными
	if _, err := io.ReadFull(srvConn, buff[:1]); err != nil {
		return false, merrs.Join(err, merrs.NewFailedReadResponseError(err.Error()))
	}
	// Расшифровываем размер сигнального-пакета используя сгенерированный нами XorKey
	size := buff[0] ^ xorKey[0]
	// Вcе содержимое сигнального-пакета является мусором, поэтому просто читаем его в dev/null
	if _, err := io.CopyN(io.Discard, srvConn, int64(size)); err != nil {
		return false, merrs.Join(err, merrs.NewFailedReadResponseError(err.Error()))
	}
	// Устанавливаем связь с локальным портом, который нужно связать с соединением сервера
	agentConn, err := c.connectToPort(params.Host, params.Port)
	if err != nil {
		// Если попали сюда, то мы смогли успешно пройти все этапы открытия порта связанные с сервером, но не смогли
		// установить соединение с локальным портом (возможно, процесс управляемый данным портом не запущен). В этом
		// случае правильным решением будет повторить все действия сначала. Это приведет к:
		// - разрыву текущего соединения между сервером и оператором;
		// - повторному ожиданию переподключения оператора к серверу на стороне агента.
		// Так как связь будет разорвана, у оператора появится возможность разобраться в ситуации и предпринять
		// необходимые действия, прежде чем снова повторять попытку подключения к порту.

		// 1. Здесь можно добавить отправку ошибки на сервер
		return true, err
	}
	defer agentConn.Close()
	// 2. Здесь можно добавить отправку короткого ответа об успешном подключении
	// ---
	// Если возникнет надобность, то "1 и 2" нужно реализовывать одновременно

	go proxy(srvConn, agentConn)
	err = proxy(agentConn, srvConn)
	return true, err
}

func (c *tunPortOpen) connectToSrv(host string, port int) (srvConn net.Conn, err error) {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.srvConn, err = connectTo(host, port)
	if err != nil {
		return nil, merrs.Join(err, merrs.NewFailedConnectToError(fmt.Sprintf("%s:%d", host, port)))
	}
	return c.srvConn, nil
}

func (c *tunPortOpen) connectToPort(host string, port int) (srvConn net.Conn, err error) {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.agentConn, err = connectTo(host, port, 3)
	if err != nil {
		return nil, merrs.Join(err, merrs.NewFailedConnectToError(fmt.Sprintf("%s:%d", host, port)))
	}
	return c.agentConn, nil
}
