#include "Metal/TestingEnvironmentMtl.hpp"
#include "Metal/TestingSwapChainMtl.hpp"

#include "DeviceContextMtl.h"
#include "TextureViewMtl.h"

#include "InlineShaders/ComputeShaderTestMSL.h"

namespace Diligent
{

namespace Testing
{

namespace
{

// Returns a printable description of an NSError reported by Metal, or a placeholder
// when no error object was produced. The returned pointer is only valid while the
// enclosing autorelease pool is alive, so use it immediately.
const char* GetErrorDescription(NSError* error)
{
    return error != nil ? [[error localizedDescription] UTF8String] : "no error information";
}

} // namespace

/// Reference implementation of the compute-shader test for the Metal backend.
///
/// Compiles MSL::FillTextureCS and creates a compute pipeline from its "CSMain"
/// function. It then dispatches that pipeline over the current back buffer of the
/// swap chain, bound through its UAV at texture index 0, using 16x16 threadgroups.
/// The command buffer is committed but not waited on.
///
/// This code uses manual reference counting: objects returned by Metal `new...`
/// methods are owned here and are released explicitly, including on the early-return
/// paths taken by failing ASSERT_TRUE checks.
///
/// @param pSwapChain  Swap chain to render into; it is cast to TestingSwapChainMtl.
void ComputeShaderReferenceMtl(ISwapChain* pSwapChain)
{
    auto* const pEnv      = TestingEnvironmentMtl::GetInstance();
    auto const  mtlDevice = pEnv->GetMtlDevice();

    @autoreleasepool
    {
        // Autoreleased
        auto* progSrc = [NSString stringWithUTF8String:MSL::FillTextureCS.c_str()];
        NSError *errors = nil; // Autoreleased
        id <MTLLibrary> library = [mtlDevice newLibraryWithSource:progSrc
                                                          options:nil
                                                            error:&errors];
        ASSERT_TRUE(library != nil) << "Failed to compile MSL source: " << GetErrorDescription(errors);

        id <MTLFunction> computeFunc = [library newFunctionWithName:@"CSMain"];
        // Release the library before asserting so that a failed assertion (which
        // returns early) does not leak it.
        [library release];
        ASSERT_TRUE(computeFunc != nil) << "Function 'CSMain' was not found in the compiled library";

        auto* computePipeline = [mtlDevice newComputePipelineStateWithFunction:computeFunc error:&errors];
        // Same as above: drop the function before a possible early return.
        [computeFunc release];
        ASSERT_TRUE(computePipeline != nil) << "Failed to create compute pipeline: " << GetErrorDescription(errors);

        auto* pTestingSwapChainMtl = ValidatedCast<TestingSwapChainMtl>(pSwapChain);
        auto* pUAV                 = pTestingSwapChainMtl->GetCurrentBackBufferUAV();
        auto* mtlTexture           = ValidatedCast<ITextureViewMtl>(pUAV)->GetMtlTexture();
        const auto& SCDesc         = pTestingSwapChainMtl->GetDesc();

        auto* mtlCommandQueue = pEnv->GetMtlCommandQueue();

        // Command buffer is autoreleased
        id <MTLCommandBuffer> mtlCommandBuffer = [mtlCommandQueue commandBuffer];
        // Command encoder is autoreleased
        auto* cmdEncoder = [mtlCommandBuffer computeCommandEncoder];
        if (cmdEncoder == nil)
        {
            // ASSERT_TRUE below returns early; do not leak the owned pipeline state.
            [computePipeline release];
        }
        ASSERT_TRUE(cmdEncoder != nil);

        [cmdEncoder setComputePipelineState:computePipeline];
        [cmdEncoder setTexture:mtlTexture atIndex:0];
        // Round the grid size up so that partially covered 16x16 tiles are still dispatched.
        [cmdEncoder dispatchThreadgroups:MTLSizeMake((SCDesc.Width + 15) / 16, (SCDesc.Height + 15) / 16, 1)
                   threadsPerThreadgroup:MTLSizeMake(16, 16, 1)];

        [cmdEncoder endEncoding];
        [mtlCommandBuffer commit];

        // newComputePipelineStateWithFunction: returned an owned (+1) object; release our
        // reference to avoid leaking it.
        // NOTE(review): relies on Metal keeping the pipeline state alive for the committed
        // (not yet completed) command buffer — confirm if validation layers complain.
        [computePipeline release];
    }
}

} // namespace Testing

} // namespace Diligent