package utils

import (
	"agent/modules/console/defines/merrs"
	"fmt"
	"syscall"
	"unsafe"
)

// TH32CS_SNAPPROCESS is the CreateToolhelp32Snapshot flag requesting that all
// processes in the system be included in the snapshot.
const TH32CS_SNAPPROCESS = 0x2

// processEntry32 mirrors the Win32 PROCESSENTRY32W structure that
// Process32FirstW / Process32NextW fill in. A pointer to it is handed straight
// to kernel32, so field order and sizes must match the native layout exactly;
// do not reorder, retype or add fields.
type processEntry32 struct {
	// Size must be set to the size of this struct before the first
	// Process32FirstW call.
	Size              uint32
	CntUsage          uint32
	ProcessID         uint32
	// DefaultHeapID is pointer-sized (ULONG_PTR in the native struct), which
	// keeps the following fields correctly aligned on 32- and 64-bit builds.
	DefaultHeapID     uintptr
	ModuleID          uint32
	Threads           uint32
	ParentProcessID   uint32
	PriorityClassBase int32
	Flags             uint32
	// ExeFile holds the executable name as NUL-terminated UTF-16 (the W API
	// variant), in a MAX_PATH (260) sized buffer.
	ExeFile           [260]uint16
}

// kernel32.dll and the entry points used by this file. Both the DLL and each
// procedure are resolved lazily on first use.
var (
	modKernel32 = syscall.NewLazyDLL("kernel32.dll")

	// Toolhelp32 snapshot enumeration; the W variants match the UTF-16
	// ExeFile buffer in processEntry32.
	procCreateToolhelp32Snapshot = modKernel32.NewProc("CreateToolhelp32Snapshot")
	procProcess32First           = modKernel32.NewProc("Process32FirstW")
	procProcess32Next            = modKernel32.NewProc("Process32NextW")

	// Process termination and handle cleanup.
	procOpenProcess      = modKernel32.NewProc("OpenProcess")
	procTerminateProcess = modKernel32.NewProc("TerminateProcess")
	procCloseHandle      = modKernel32.NewProc("CloseHandle")
)

// PROCESS_TERMINATE is the OpenProcess access right requested before calling
// TerminateProcess on the returned handle.
var PROCESS_TERMINATE = 0x0001

// KillProcessTree terminates pid together with all of its descendants.
//
// Descendants are discovered with FindChildProcesses and killed depth-first,
// children before their parent. The operation is best-effort: failures to
// enumerate or kill descendants are ignored, and only the result of killing
// pid itself is returned.
func KillProcessTree(pid int64) error {
	return killProcessTreeGuarded(pid, make(map[int64]struct{}))
}

// killProcessTreeGuarded implements KillProcessTree while recording every PID
// it has already visited, so a parent/child cycle in the process snapshot
// (possible because Windows reuses PIDs, and PID 0 reports itself as its own
// parent) cannot cause unbounded recursion.
func killProcessTreeGuarded(pid int64, visited map[int64]struct{}) error {
	visited[pid] = struct{}{}

	// Enumeration errors are deliberately ignored: any children found before
	// the failure are still worth killing.
	children, _ := FindChildProcesses(pid)
	for _, child := range children {
		if _, seen := visited[child]; seen {
			continue
		}
		// Best-effort: a child that already exited or cannot be opened must
		// not prevent the rest of the tree from being killed.
		_ = killProcessTreeGuarded(child, visited)
	}
	return KillSingleProcess(pid)
}

// FindChildProcesses returns the PIDs of all processes whose recorded parent
// PID equals parentPID, taken from a Toolhelp32 process snapshot.
//
// Only direct children are returned. A process is never reported as its own
// child (the System Idle Process, PID 0, lists itself as its parent), which
// keeps recursive callers from looping. Because Windows reuses PIDs, a stale
// parent PID can make an unrelated process appear as a child.
//
// Any error is joined with merrs.NewFailedListProcessError. If enumeration
// fails part-way, the children found so far are still returned with the error.
func FindChildProcesses(parentPID int64) (res []int64, err error) {
	defer func() {
		if err != nil {
			err = merrs.Join(err, merrs.NewFailedListProcessError())
		}
	}()

	// LazyProc.Call always returns a non-nil error (the GetLastError value),
	// so success must be judged from the primary return value, never from the
	// error itself.
	snapshot, _, lastErr := procCreateToolhelp32Snapshot.Call(TH32CS_SNAPPROCESS, 0)
	if snapshot == uintptr(syscall.InvalidHandle) {
		return nil, lastErr
	}
	defer procCloseHandle.Call(snapshot)

	var entry processEntry32
	entry.Size = uint32(unsafe.Sizeof(entry))

	more, _, lastErr := procProcess32First.Call(snapshot, uintptr(unsafe.Pointer(&entry)))
	for more != 0 {
		pid := int64(entry.ProcessID)
		if int64(entry.ParentProcessID) == parentPID && pid != parentPID {
			res = append(res, pid)
		}
		more, _, lastErr = procProcess32Next.Call(snapshot, uintptr(unsafe.Pointer(&entry)))
	}

	// Process32First/Next return FALSE with ERROR_NO_MORE_FILES once the
	// snapshot is exhausted; any other code is a genuine enumeration failure.
	if lastErr != syscall.ERROR_NO_MORE_FILES {
		return res, lastErr
	}
	return res, nil
}

// KillSingleProcess forcibly terminates the process identified by pid with
// exit code 1. It does not touch the process's children.
//
// On failure the returned error is joined with merrs.NewFailedKillProcessError
// naming the pid. The inner error also carries NewFailedOpenProcessError or
// NewFailedTreminateProcessError, depending on which step failed.
func KillSingleProcess(pid int64) (err error) {
	defer func() {
		if err == nil {
			return
		}
		err = merrs.Join(err, merrs.NewFailedKillProcessError(fmt.Sprintf(`pid "%d"`, pid)))
	}()

	const exitCode = 1

	handle, _, openErr := procOpenProcess.Call(uintptr(PROCESS_TERMINATE), 0, uintptr(pid))
	if handle == 0 {
		return merrs.Join(openErr, merrs.NewFailedOpenProcessError())
	}
	defer procCloseHandle.Call(handle)

	// TerminateProcess returns zero (FALSE) on failure.
	if ok, _, termErr := procTerminateProcess.Call(handle, exitCode); ok == 0 {
		return merrs.Join(termErr, merrs.NewFailedTreminateProcessError())
	}
	return nil
}

// KillChildProcess terminates every descendant of pid (each child via
// KillProcessTree) while leaving pid itself running.
//
// It is best-effort and always returns nil. Enumeration and kill failures are
// ignored, so that one stubborn child does not prevent the others from being
// killed.
func KillChildProcess(pid int64) error {
	children, _ := FindChildProcesses(pid)
	for _, child := range children {
		// Never kill the requested process itself, even if the snapshot
		// reports it as its own child (e.g. PID 0).
		if child == pid {
			continue
		}
		_ = KillProcessTree(child)
	}
	return nil
}
