// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/hip/HIPCachingAllocator.h>
#include <c10/hip/HIPException.h>
#include <c10/hip/HIPFunctions.h>
#include <c10/hip/HIPStream.h>

#include <hip/hip_runtime_api.h>

namespace c10 {
namespace hip {
namespace impl {

struct HIPGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::HIP;

  HIPGuardImpl() = default;
  explicit HIPGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::HIP);
  }
  DeviceType type() const override {
    return DeviceType::HIP;
  }
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_hip());
    Device old_device = getDevice();
    if (old_device.index() != d.index()) {
      C10_HIP_CHECK(hipSetDevice(d.index()));
    }
    return old_device;
  }
  Device getDevice() const override {
    int device;
    C10_HIP_CHECK(hipGetDevice(&device));
    return Device(DeviceType::HIP, device);
  }
  c10::optional<Device> uncheckedGetDevice() const noexcept {
    int device;
    const auto err = C10_HIP_ERROR_HANDLED(hipGetDevice(&device));
    C10_HIP_CHECK_WARN(err);
    if (err != hipSuccess) {
      return c10::nullopt;
    }
    return Device(DeviceType::HIP, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_hip());
    Device current_device = getDevice();
    if (current_device != d) {
      C10_HIP_CHECK(hipSetDevice(d.index()));
    }
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    auto current_device = uncheckedGetDevice();
    if (!current_device.has_value() || current_device.value() != d) {
      C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
    }
  }
  Stream getStream(Device d) const noexcept override {
    return getCurrentHIPStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultHIPStream(d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    HIPStream cs(s);
    auto old_stream = getCurrentHIPStream(s.device().index());
    setCurrentHIPStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions
  void createEvent(hipEvent_t* hip_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to HIP flag
    auto hip_flag = hipEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
      case EventFlag::HIP_EVENT_DISABLE_TIMING:
        hip_flag = hipEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
      case EventFlag::HIP_EVENT_DEFAULT:
        hip_flag = hipEventDefault;
        break;
      default:
        TORCH_CHECK(false, "HIP event received unknown flag");
    }

    C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          reinterpret_cast<uintptr_t>(hip_event));
    }
  }

  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto hip_event = static_cast<hipEvent_t>(event);
    int orig_device;
    C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
    C10_HIP_CHECK_WARN(hipSetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          reinterpret_cast<uintptr_t>(hip_event));
    }
    C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
    C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
  }

  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
    HIPStream hip_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!hip_event)
      createEvent(&hip_event, flag);
    C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
    // Makes the void* point to the (possibly just allocated) HIP event
    *event = hip_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          reinterpret_cast<uintptr_t>(hip_event),
          reinterpret_cast<uintptr_t>(hip_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }

  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    HIPStream hip_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_HIP_CHECK(hipStreamWaitEvent(
        hip_stream,
        hip_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          reinterpret_cast<uintptr_t>(hip_event),
          reinterpret_cast<uintptr_t>(hip_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
    const hipError_t err = C10_HIP_ERROR_HANDLED(hipEventQuery(hip_event));
    if (err != hipErrorNotReady) {
      C10_HIP_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)hipGetLastError();
    }
    return (err == hipSuccess);
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    HIPStream hip_stream{stream};
    return hip_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    HIPStream hip_stream{stream};
    hip_stream.synchronize();
  }

  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    HIPStream hip_stream{stream};
    HIPCachingAllocator::recordStream(data_ptr, hip_stream);
  }
};

} // namespace impl
} // namespace hip
} // namespace c10
