#include "CudaCommon.hpp"
#include "java/com_ibm_cuda_Cuda.h"
#include "java/com_ibm_cuda_Cuda_Cleaner.h"
/**
 * Raise a CudaException carrying the given error code on the calling thread.
 *
 * The exception class and its (int) constructor are the ones cached in
 * javaVM->cudaGlobals by Cuda.initialize. If an exception is already pending,
 * it is left in place and only a "suppressed" trace point is recorded.
 *
 * @param[in] env    the JNI environment of the current thread
 * @param[in] error  the error code passed to the exception constructor
 */
void
throwCudaException(JNIEnv * env, int32_t error)
{
	J9VMThread * thread = (J9VMThread *)env;

	Trc_cuda_throwCudaException_entry(thread, error);

	J9CudaGlobals * cudaState = thread->javaVM->cudaGlobals;

	/* The cached class/constructor only exist after a successful initialize(). */
	Assert_cuda_true(NULL != cudaState);
	Assert_cuda_true(NULL != cudaState->exception_init);

	if (!env->ExceptionCheck()) {
		jthrowable toThrow = (jthrowable)env->NewObject(cudaState->exceptionClass, cudaState->exception_init, error);

		/* Per the JNI spec, a NULL result means construction failed and that failure is already pending. */
		if (NULL != toThrow) {
			env->Throw(toThrow);
		}
	} else {
		Trc_cuda_throwCudaException_suppressed(thread);
	}

	Trc_cuda_throwCudaException_exit(thread);
}
/**
 * Native for Cuda.allocatePinnedBuffer.
 *
 * Allocates byteCount bytes through j9cuda_hostAlloc using the
 * J9CUDA_HOST_ALLOC_DEFAULT flags. When built without OMR_OPT_CUDA, nothing
 * is allocated and a CudaException(J9CUDA_ERROR_NO_DEVICE) is raised.
 *
 * @param[in] byteCount  requested size in bytes
 * @return the buffer address as a jlong; a CudaException is pending on failure
 */
jlong JNICALL
Java_com_ibm_cuda_Cuda_allocatePinnedBuffer
	(JNIEnv * env, jclass, jlong byteCount)
{
	J9VMThread * thread = (J9VMThread *)env;

	Trc_cuda_allocatePinnedBuffer_entry(thread, byteCount);

	void * buffer = NULL;

#ifdef OMR_OPT_CUDA
	PORT_ACCESS_FROM_ENV(env);
	int32_t const status = j9cuda_hostAlloc((size_t)byteCount, J9CUDA_HOST_ALLOC_DEFAULT, &buffer);
#else
	int32_t const status = J9CUDA_ERROR_NO_DEVICE;
#endif

	if (0 != status) {
		throwCudaException(env, status);
	}

	Trc_cuda_allocatePinnedBuffer_exit(thread, buffer);

	return (jlong)buffer;
}
/**
 * Native for Cuda.getDeviceCount.
 *
 * Returns the count reported by j9cuda_deviceGetCount and raises a
 * CudaException if that call fails. When built without OMR_OPT_CUDA it
 * reports 0 devices and raises no exception.
 */
jint JNICALL
Java_com_ibm_cuda_Cuda_getDeviceCount
	(JNIEnv * env, jclass)
{
	J9VMThread * thread = (J9VMThread *)env;
	uint32_t deviceCount = 0;

	Trc_cuda_getDeviceCount_entry(thread);

#ifdef OMR_OPT_CUDA
	{
		PORT_ACCESS_FROM_ENV(env);
		int32_t const status = j9cuda_deviceGetCount(&deviceCount);

		if (0 != status) {
			throwCudaException(env, status);
		}
	}
#endif

	Trc_cuda_getDeviceCount_exit(thread, deviceCount);

	return (jint)deviceCount;
}
/**
 * Native for Cuda.getDriverVersion.
 *
 * Returns the version reported by j9cuda_driverGetVersion and raises a
 * CudaException on failure. When built without OMR_OPT_CUDA it returns 0 and
 * raises CudaException(J9CUDA_ERROR_NO_DEVICE).
 */
jint JNICALL
Java_com_ibm_cuda_Cuda_getDriverVersion
	(JNIEnv * env, jclass)
{
	J9VMThread * thread = (J9VMThread *)env;

	Trc_cuda_getDriverVersion_entry(thread);

	uint32_t driverVersion = 0;

#ifdef OMR_OPT_CUDA
	PORT_ACCESS_FROM_ENV(env);
	int32_t const status = j9cuda_driverGetVersion(&driverVersion);
#else
	int32_t const status = J9CUDA_ERROR_NO_DEVICE;
#endif

	if (0 != status) {
		throwCudaException(env, status);
	}

	Trc_cuda_getDriverVersion_exit(thread, driverVersion);

	return (jint)driverVersion;
}
/**
 * Native for Cuda.getErrorMessage.
 *
 * Maps an error code to a descriptive Java string. With OMR_OPT_CUDA the text
 * comes from j9cuda_getErrorString. Without it, only three codes have
 * built-in text: J9CUDA_NO_ERROR, J9CUDA_ERROR_MEMORY_ALLOCATION and
 * J9CUDA_ERROR_NO_DEVICE.
 *
 * @param[in] code  the error code to describe
 * @return a new Java string, or NULL when no message is known for the code
 */
jstring JNICALL
Java_com_ibm_cuda_Cuda_getErrorMessage
	(JNIEnv * env, jclass, jint code)
{
	J9VMThread * thread = (J9VMThread *)env;

	Trc_cuda_getErrorMessage_entry(thread, code);

	const char * text = NULL;

#ifdef OMR_OPT_CUDA
	PORT_ACCESS_FROM_ENV(env);
	text = j9cuda_getErrorString(code);
#else
	if (J9CUDA_NO_ERROR == code) {
		text = "no error";
	} else if (J9CUDA_ERROR_MEMORY_ALLOCATION == code) {
		text = "memory allocation failed";
	} else if (J9CUDA_ERROR_NO_DEVICE == code) {
		text = "no CUDA-capable device is detected";
	}
#endif

	Trc_cuda_getErrorMessage_exit(thread, (NULL != text) ? text : "(null)");

	jstring message = NULL;

	if (NULL != text) {
		message = env->NewStringUTF(text);
	}

	return message;
}
/**
 * Native for Cuda.getRuntimeVersion.
 *
 * Returns the version reported by j9cuda_runtimeGetVersion and raises a
 * CudaException on failure. When built without OMR_OPT_CUDA it returns 0 and
 * raises CudaException(J9CUDA_ERROR_NO_DEVICE).
 */
jint JNICALL
Java_com_ibm_cuda_Cuda_getRuntimeVersion
	(JNIEnv * env, jclass)
{
	J9VMThread * thread = (J9VMThread *)env;

	Trc_cuda_getRuntimeVersion_entry(thread);

	uint32_t runtimeVersion = 0;

#ifdef OMR_OPT_CUDA
	PORT_ACCESS_FROM_ENV(env);
	int32_t const status = j9cuda_runtimeGetVersion(&runtimeVersion);
#else
	int32_t const status = J9CUDA_ERROR_NO_DEVICE;
#endif

	if (0 != status) {
		throwCudaException(env, status);
	}

	Trc_cuda_getRuntimeVersion_exit(thread, runtimeVersion);

	return (jint)runtimeVersion;
}
/**
 * Native for Cuda.initialize.
 *
 * Registers the trace module, then builds a J9CudaGlobals record containing:
 * a global reference to exceptionClass, the method ID of its (I)V
 * constructor, and the method ID for runMethod. On success the record is
 * stored in javaVM->cudaGlobals.
 *
 * If any step after allocation fails, everything acquired so far is released
 * and javaVM->cudaGlobals is set to NULL. If the allocation itself fails,
 * javaVM->cudaGlobals is not touched.
 *
 * @param[in] exceptionClass  the CudaException class (must not be NULL)
 * @param[in] runMethod       a reflected method object (must not be NULL)
 * @return 0 on success, otherwise J9CUDA_ERROR_MEMORY_ALLOCATION or
 *         J9CUDA_ERROR_INITIALIZATION_ERROR
 */
jint JNICALL
Java_com_ibm_cuda_Cuda_initialize
	(JNIEnv * env, jclass, jclass exceptionClass, jobject runMethod)
{
	jint status = 0;
	J9VMThread * thread = (J9VMThread *)env;
	J9JavaVM * javaVM = thread->javaVM;

	UT_MODULE_LOADED(J9_UTINTERFACE_FROM_VM(javaVM));

	Trc_cuda_initialize_entry(thread);

	Assert_cuda_true(NULL != exceptionClass);
	Assert_cuda_true(NULL != runMethod);

	PORT_ACCESS_FROM_JAVAVM(javaVM);

	J9CudaGlobals * state = (J9CudaGlobals *)J9CUDA_ALLOCATE_MEMORY(sizeof(J9CudaGlobals));

	if (NULL == state) {
		Trc_cuda_initialize_fail(thread, "allocate globals");
		status = J9CUDA_ERROR_MEMORY_ALLOCATION;
	} else {
		memset(state, 0, sizeof(J9CudaGlobals));

		state->exceptionClass = (jclass)env->NewGlobalRef(exceptionClass);

		if (NULL == state->exceptionClass) {
			Trc_cuda_initialize_fail(thread, "create NewGlobalRef for CudaException");
			status = J9CUDA_ERROR_MEMORY_ALLOCATION;
		} else {
			state->exception_init = env->GetMethodID(state->exceptionClass, "<init>", "(I)V");

			if (NULL == state->exception_init) {
				Trc_cuda_initialize_fail(thread, "find CudaException constructor");
				status = J9CUDA_ERROR_INITIALIZATION_ERROR;
			} else {
				state->runnable_run = env->FromReflectedMethod(runMethod);

				if (NULL == state->runnable_run) {
					Trc_cuda_initialize_fail(thread, "get method handle");
					status = J9CUDA_ERROR_INITIALIZATION_ERROR;
				}
			}

			/* The global reference exists on this path; drop it if a later step failed. */
			if (0 != status) {
				env->DeleteGlobalRef(state->exceptionClass);
			}
		}

		/* Any failure after the allocation releases the record and publishes NULL. */
		if (0 != status) {
			J9CUDA_FREE_MEMORY(state);
			state = NULL;
		}

		javaVM->cudaGlobals = state;
	}

	Trc_cuda_initialize_exit(thread, status);

	return status;
}
#ifdef OMR_OPT_CUDA
/**
 * Native for Cuda.wrapDirectBuffer.
 *
 * Wraps the native memory region [buffer, buffer + capacity) in a direct
 * java.nio.ByteBuffer via JNI NewDirectByteBuffer. The memory is not copied
 * and ownership is not transferred.
 *
 * @return the new ByteBuffer, or NULL if NewDirectByteBuffer failed
 */
jobject JNICALL
Java_com_ibm_cuda_Cuda_wrapDirectBuffer
	(JNIEnv * env, jclass, jlong buffer, jlong capacity)
{
	J9VMThread * thread = (J9VMThread *)env;
	void * const address = (void *)(uintptr_t)buffer;

	Trc_cuda_wrapDirectBuffer_entry(thread, (uintptr_t)buffer, capacity);

	jobject byteBuffer = env->NewDirectByteBuffer(address, capacity);

	Trc_cuda_wrapDirectBuffer_exit(thread, byteBuffer);

	return byteBuffer;
}
/**
 * Native for Cuda.Cleaner.releasePinnedBuffer.
 *
 * Frees a host buffer through j9cuda_hostFree. Presumably the buffer was
 * obtained from Cuda.allocatePinnedBuffer (TODO confirm against the Java
 * caller). Raises a CudaException if the free reports an error.
 *
 * @param[in] address  the buffer address as returned to Java
 */
void JNICALL
Java_com_ibm_cuda_Cuda_00024Cleaner_releasePinnedBuffer
	(JNIEnv * env, jclass, jlong address)
{
	J9VMThread * thread = (J9VMThread *)env;
	void * const buffer = (void *)(uintptr_t)address;

	Trc_cuda_releasePinnedBuffer_entry(thread, (uintptr_t)address);

	PORT_ACCESS_FROM_ENV(env);

	int32_t const status = j9cuda_hostFree(buffer);

	if (0 != status) {
		throwCudaException(env, status);
	}

	Trc_cuda_releasePinnedBuffer_exit(thread);
}
#endif
/**
 * Library unload hook.
 *
 * Releases the state created by Cuda.initialize: it deletes the global
 * reference to the exception class and frees the J9CudaGlobals record, then
 * clears javaVM->cudaGlobals. The global reference is deleted only if a
 * JNIEnv can be obtained for the current thread. The trace module is
 * unregistered in every case.
 */
void JNICALL
JNI_OnUnload(JavaVM * jvm, void *)
{
	J9JavaVM * javaVM = (J9JavaVM *)jvm;
	J9CudaGlobals * const state = javaVM->cudaGlobals;

	PORT_ACCESS_FROM_JAVAVM(javaVM);

	if (NULL != state) {
		JNIEnv * env = NULL;

		/* DeleteGlobalRef needs a JNIEnv; GetEnv is only attempted when there is a reference to drop. */
		if ((NULL != state->exceptionClass) && (JNI_OK == jvm->GetEnv((void **)&env, JNI_VERSION_1_2))) {
			env->DeleteGlobalRef(state->exceptionClass);
		}

		J9CUDA_FREE_MEMORY(state);
		javaVM->cudaGlobals = NULL;
	}

	UT_MODULE_UNLOADED(J9_UTINTERFACE_FROM_VM(javaVM));
}