# To run in CCB:
#   module load python-cbrg
#   module load cuda/12.2
#   python ./cuda-example.py

import sys, time, os, subprocess, hashlib
import numpy as np

from numba import cuda, uint8, uint64
# Numba = just-in-time compiler.
# Write an arbitrary function in CUDA-aware Python,
# and Numba can compile it to NVIDIA PTX code.

import cupy
# CuPy = NumPy for CUDA
# Array-based interface to CUDA

# Input FASTA read by main(): Ensembl ARS-UCD1.3 (Bos taurus) CDS sequences.
# Absolute path; edit this to run the example on a different FASTA file.
fa_filepath = "/databank/igenomes/Bos_taurus/Ensembl/ARS-UCD1.3/Sequence/cds/Bos_taurus.ARS-UCD1.3.cds.all.fa"

def read_fasta_ACGT(path: str) -> np.ndarray:
    """Read a FASTA file and return its A/C/G/T content as compact base codes.

    Header lines (those starting with ``>``) are skipped. From every other
    line, only the bytes A, C, G, T (either case) are kept; everything else
    (N, IUPAC codes, newlines, carriage returns, ...) is dropped.

    Args:
        path: Path to the FASTA file.

    Returns:
        A 1-D ``uint8`` array with one entry per kept base:
        0=A, 1=C, 2=G, 3=T. All records are concatenated in file order.
    """
    # Lookup table: byte value -> base code, 255 for "not a base".
    # Lower- and upper-case letters map to the same code, so no separate
    # case-normalisation pass is needed.
    lut = np.full(256, 255, dtype=np.uint8)
    for code, letters in enumerate((b"Aa", b"Cc", b"Gg", b"Tt")):
        for letter in letters:
            lut[letter] = code

    # Accumulate whole sequence lines; per-line work is cheap compared with
    # the original per-byte Python loop, and bytearray += is amortised O(1).
    buf = bytearray()
    with open(path, "rb") as f:
        for line in f:
            if not line.startswith(b">"):
                buf += line

    # Vectorised translation and filtering in native code.
    codes = lut[np.frombuffer(buf, dtype=np.uint8)]
    return codes[codes != 255]

@cuda.jit
def count_bases_kernel(seq, n, counts4):
    """Count how often the codes 0, 1, 2, 3 occur in seq[0:n] (GPU kernel).

    Each thread starts at its global index and advances by the total number
    of threads in the grid, so any launch size covers all n elements. Each
    thread keeps its own local tallies and at the end atomically adds its
    non-zero tallies into counts4. Values other than 0..3 are ignored.

    Args:
        seq: device array of base codes (0=A, 1=C, 2=G, 3=T, as produced by
            read_fasta_ACGT).
        n: number of leading elements of seq to scan.
        counts4: device array of 4 counters. The kernel adds to it and does
            not reset it, so the caller must zero it first.
    """
    # counts4[0]=A, [1]=C, [2]=G, [3]=T
    i = cuda.grid(1)  # this thread's global (1-D) index
    stride = cuda.gridsize(1)  # total number of threads in the launched grid

    # Per-thread tallies: each thread performs at most 4 atomics in total
    # instead of one atomic per element.
    # NOTE(review): these start as uint64 and are incremented with int
    # literals; confirm Numba's integer promotion keeps them integral.
    a = uint64(0); c = uint64(0); g = uint64(0); t = uint64(0)
    for idx in range(i, n, stride):
        v = seq[idx]
        if v == 0:
            a += 1
        elif v == 1:
            c += 1
        elif v == 2:
            g += 1
        elif v == 3:
            t += 1

    # Skip the atomic when this thread saw none of a base, avoiding
    # needless contention on the shared counters.
    if a: cuda.atomic.add(counts4, 0, a)
    if c: cuda.atomic.add(counts4, 1, c)
    if g: cuda.atomic.add(counts4, 2, g)
    if t: cuda.atomic.add(counts4, 3, t)

def count_bases(seq_gpu, n, counts_gpu):
    """Run count_bases_kernel over seq_gpu[0:n] and block until it finishes.

    The kernel adds its counts into counts_gpu, so the caller must zero
    counts_gpu beforehand.

    Args:
        seq_gpu: device array of base codes (0=A, 1=C, 2=G, 3=T).
        n: number of leading elements of seq_gpu to scan.
        counts_gpu: device array of 4 counters (A, C, G, T), updated in place.
    """
    if n <= 0:
        # Nothing to count. With n == 0 the block count below would be 0,
        # which is an invalid launch configuration, so return early and
        # leave counts_gpu untouched.
        return
    threads = 256
    # Aim for one thread per element, but cap the grid at 65535 blocks.
    # The kernel strides by the grid size, so a capped grid still covers
    # every element.
    blocks = min(65535, (n + threads - 1) // threads)
    count_bases_kernel[blocks, threads](seq_gpu, n, counts_gpu)
    cuda.synchronize()

def compute_user_signature_cupy(user_id):
    """Derive a short signature from ``user_id`` via a small GPU computation.

    The signature is evidence that this user ran code on a GPU. It is
    computed as follows:

    * SHA-256 of the UTF-8 encoded id gives 32 bytes.
    * On the GPU, the bytes become float32 values in [-0.5, 0.5].
    * Each value goes through sin(12.345*x) + tanh(3.21*x).
    * The results are weighted by 1..32 and summed.
    * The sum is formatted to 12 decimals.
    * The first 8 hex characters of its MD5 are returned, so the raw float
      is never exposed.
    """
    digest = hashlib.sha256(user_id.encode("utf-8")).digest()

    # Upload the 32 digest bytes and centre them around zero (float32 on GPU).
    vals = cupy.frombuffer(digest, dtype=cupy.uint8).astype(cupy.float32)
    vals = vals / 255.0 - 0.5

    # Arbitrary but deterministic element-wise mixing, then a 1..N ramp weighting.
    mixed = cupy.sin(vals * 12.345) + cupy.tanh(vals * 3.21)
    ramp = cupy.arange(1, vals.size + 1, dtype=cupy.float32)
    total = float(cupy.sum(mixed * ramp).item())

    fingerprint = hashlib.md5(f"{total:.12f}".encode("utf-8"))
    return fingerprint.hexdigest()[:8]

def main():
    """Load the FASTA, print a run signature, repeatedly count bases on the GPU, report.

    The counting pass (host->device copy, kernel, device->host copy) is
    repeated 2000 times purely to keep the GPU busy long enough to observe it
    (e.g. with nvidia-smi). The signature line is printed before and after
    the GPU work.
    """
    print("Loading data ...")
    seq_host = read_fasta_ACGT(fa_filepath)

    pid = os.getpid()
    job = os.getenv('SLURM_JOB_ID')
    id_output = subprocess.check_output(["id"], text=True).strip()
    sig = compute_user_signature_cupy(id_output)
    banner = f"Run signature: PID={pid}, job={job}, user={sig}"
    print(banner)

    print("Starting GPU work")

    n_bases = seq_host.size
    # Reported as ~19 passes per second, so 2000 passes is roughly two
    # minutes: enough time to run "nvidia-smi" while this is going.
    n_passes = 2000
    for _ in range(n_passes):
        # Deliberately re-upload the whole sequence each pass (CPU -> GPU).
        seq_gpu = cuda.to_device(seq_host)
        # Fresh zeroed counters on the device; a CuPy equivalent would be
        # cupy.zeros(4, dtype=cupy.uint64).
        counts_gpu = cuda.to_device(np.zeros(4, np.uint64))
        count_bases(seq_gpu, n_bases, counts_gpu)
        # Bring the 4 counts back to the host (GPU -> CPU).
        counts_host = counts_gpu.copy_to_host()

    A, C, G, T = counts_host.tolist()
    total = int(A + C + G + T)
    gc = 100.0 * (G + C) / total if total else 0.0
    print(f"A: {A}\nC: {C}\nG: {G}\nT: {T}\nGC%: {gc}")

    print(banner)

# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
