import { useFrame, useThree } from "@react-three/fiber";
import {
  Vector2,
  Vector3,
  InstancedBufferGeometry,
  BoxBufferGeometry,
  MeshStandardMaterial,
  InstancedBufferAttribute,
} from "three";
import WebGPURenderer from "three/examples/jsm/renderers/webgpu/WebGPURenderer";
import WebGPUStorageBuffer from "three/examples/jsm/renderers/webgpu/WebGPUStorageBuffer";
import WebGPUUniformsGroup from "three/examples/jsm/renderers/webgpu/WebGPUUniformsGroup";
import { Vector2Uniform } from "three/examples/jsm/renderers/webgpu/WebGPUUniform.js";

import {
  PositionNode,
  OperatorNode,
  AttributeNode,
  FloatNode,
} from "./import-helper";
import { useState } from "react";

// Simulation sizing: how many particles exist and how many floats each one
// occupies in the typed arrays built from these constants.
const particleNum = 256 * 256;
const particleSize = 3;

console.log("particleNum", particleNum);

// [x, y, z] sizes substituted into the shader's __LOCAL_SIZE_*__ placeholders.
const threadGroupSize = [256, 1, 1];

// Placeholder token -> value pairs, applied to the GLSL source in this order.
const shaderPlaceholders = [
  ["__PARTICLE_NUM__", particleNum],
  ["__PARTICLE_SIZE__", particleSize],
  ["__LOCAL_SIZE_X__", threadGroupSize[0]],
  ["__LOCAL_SIZE_Y__", threadGroupSize[1]],
  ["__LOCAL_SIZE_Z__", threadGroupSize[2]],
];

// Compute shader source with every placeholder token replaced by its value.
const computeShader = shaderPlaceholders.reduce(
  (source, [token, value]) => source.replaceAll(token, value),
  require("./compute-points.glsl").default
);

/**
 * Renders `particleNum` instanced boxes whose per-instance offsets come from a
 * storage buffer, and submits the compute pass described by `computeParam`
 * to the renderer from a `useFrame` callback.
 *
 * All GPU-facing objects (storage buffers, uniforms group, geometry and
 * material) are created exactly once inside the lazy `useState` initializer,
 * so re-renders of this component do not allocate a new material or node
 * graph and do not hand a different material instance to the mesh.
 *
 * NOTE(review): the `gl` object from `useThree` is cast to `WebGPURenderer`;
 * this presumably requires the surrounding canvas to be created with that
 * renderer - confirm at the call site.
 *
 * @returns A `<mesh>` element bound to the instanced geometry and node material.
 */
export function ComputePoints() {
  const [{ computeParam, geometry, material }] = useState(() => {
    // Positions are left zero-filled; only the velocities are seeded here.
    const particleArray = new Float32Array(particleNum * particleSize);
    const velocityArray = new Float32Array(particleNum * particleSize);

    // One scratch vector is reused for every particle instead of allocating
    // a fresh Vector3 per iteration.
    const velocity = new Vector3();

    for (let i = 0; i < velocityArray.length; i += 3) {
      // Each component is uniform in [-0.5, 0.5) scaled by 0.01.
      velocity
        .set(Math.random() - 0.5, Math.random() - 0.5, Math.random() - 0.5)
        .multiplyScalar(0.01);

      velocityArray[i + 0] = velocity.x;
      velocityArray[i + 1] = velocity.y;
      velocityArray[i + 2] = velocity.z;
    }

    const particleBuffer = new WebGPUStorageBuffer(
      "particle",
      new InstancedBufferAttribute(particleArray, 3)
    );
    const velocityBuffer = new WebGPUStorageBuffer(
      "velocity",
      new InstancedBufferAttribute(velocityArray, 3)
    );

    // NOTE(review): this value is never updated anywhere in this component;
    // the original comment read "Out of bounds first" - confirm the intended
    // initial pointer value against the shader.
    const pointer = new Vector2(0.5, 0.5);

    const pointerGroup = new WebGPUUniformsGroup("mouseUniforms").addUniform(
      new Vector2Uniform("pointer", pointer)
    );

    const computeBindings = [particleBuffer, velocityBuffer, pointerGroup];

    const boxSize = 0.1;
    const boxGeometry = new BoxBufferGeometry(boxSize, boxSize, boxSize);

    // Box vertices and normals are shared by all instances; the storage
    // buffer's attribute doubles as the per-instance position attribute.
    const geometry = new InstancedBufferGeometry();
    geometry
      .setAttribute("position", boxGeometry.getAttribute("position"))
      .setAttribute("instancePosition", particleBuffer.attribute)
      .setAttribute("normal", boxGeometry.getAttribute("normal"));

    geometry.setIndex(boxGeometry.getIndex());
    geometry.instanceCount = particleNum;

    // Built once here rather than in the component body, which would create
    // a new material and node graph on every render.
    const material = new MeshStandardMaterial();

    // Color is the instance position scaled by 0.3.
    (material as any).colorNode = new OperatorNode(
      "*",
      new AttributeNode("instancePosition", "vec3"),
      new FloatNode(0.3)
    );

    // Final vertex position is the box vertex plus the instance position.
    (material as any).positionNode = new OperatorNode(
      "+",
      new PositionNode(),
      new AttributeNode("instancePosition", "vec3")
    );

    (material as any).attributes = [
      new AttributeNode("instancePosition", "vec3"),
    ];

    console.log((material as any).attributes);

    return {
      computeParam: {
        // NOTE(review): assumes particleNum is a multiple of
        // threadGroupSize[0]; otherwise this count is fractional.
        num: particleNum / threadGroupSize[0],
        shader: computeShader,
        bindings: computeBindings,
      },
      geometry,
      material,
    };
  });

  const { gl } = useThree();

  useFrame(() => {
    (gl as WebGPURenderer).compute([computeParam]);
  });

  return (
    <mesh geometry={geometry} material={material} frustumCulled={false} />
  );
}
