Path: blob/master/jcl/src/openj9.cuda/share/classes/com/ibm/cuda/CudaModule.java
12917 views
/*[INCLUDE-IF Sidecar18-SE]*/1/*******************************************************************************2* Copyright (c) 2013, 2021 IBM Corp. and others3*4* This program and the accompanying materials are made available under5* the terms of the Eclipse Public License 2.0 which accompanies this6* distribution and is available at https://www.eclipse.org/legal/epl-2.0/7* or the Apache License, Version 2.0 which accompanies this distribution and8* is available at https://www.apache.org/licenses/LICENSE-2.0.9*10* This Source Code may also be made available under the following11* Secondary Licenses when the conditions for such availability set12* forth in the Eclipse Public License, v. 2.0 are satisfied: GNU13* General Public License, version 2 with the GNU Classpath14* Exception [1] and GNU General Public License, version 2 with the15* OpenJDK Assembly Exception [2].16*17* [1] https://www.gnu.org/software/classpath/license.html18* [2] http://openjdk.java.net/legal/assembly-exception.html19*20* SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 OR LicenseRef-GPL-2.0 WITH Assembly-exception21*******************************************************************************/22package com.ibm.cuda;2324import java.io.IOException;25import java.io.InputStream;26import java.util.HashMap;27import java.util.Map;28import java.util.concurrent.atomic.AtomicLong;2930import com.ibm.cuda.internal.CudaUtil;3132/**33* The {@code CudaModule} class represents a module that has been loaded34* on a CUDA-capable device.35* <p>36* When no longer required, a module must be unloaded (see {@link #unload()}).37*/38public final class CudaModule {3940/**41* The {@code Cache} class provides a simple mechanism to avoid reloading42* modules repeatedly. The set of loaded modules is specific to each device43* so two pieces of identification are required for each module: the device44* and a user-supplied key.45* <p>46* Note: Because this class is implemented with {@link HashMap}, keys47* must implement {@link #equals(Object)} and {@link #hashCode()}.48*/49public static final class Cache {5051private final Map<Object, Map<CudaDevice, CudaModule>> store;5253/**54* Creates a new cache.55*/56public Cache() {57super();58this.store = new HashMap<>(1);59}6061/**62* Retrieves an existing module for the specified device and key.63*64* @param device65* the specified device66* @param key67* the specified key68* @return69* return the module associated with the given key on the70* specified device, or null if no such module exists71*/72public CudaModule get(CudaDevice device, Object key) {73Map<?, CudaModule> map = store.get(key);7475return map == null ? null : map.get(device);76}7778/**79* Stores a module in this cache, associating it with the given80* device and key.81*82* @param device83* the specified device84* @param key85* the specified key86* @param module87* the module to be stored88* @return89* the module previously associated with the given key on90* the specified device, or null if no such module exists91*/92public CudaModule put(CudaDevice device, Object key, CudaModule module) {93Map<CudaDevice, CudaModule> map = store.get(key);9495if (map == null) {96store.put(key, map = new HashMap<>());97}9899return map.put(device, module);100}101}102103private static native long getFunction(int deviceId, long moduleHandle,104String name) throws CudaException;105106private static native long getGlobal(int deviceId, long moduleHandle,107String name) throws CudaException;108109private static native long getSurface(int deviceId, long moduleHandle,110String name) throws CudaException;111112private static native long getTexture(int deviceId, long moduleHandle,113String name) throws CudaException;114115private static native long load(int deviceId, byte[] image,116long optionsHandle) throws CudaException;117118private static native void unload(int deviceId, long moduleHandle)119throws CudaException;120121final int deviceId;122123private final Map<String, CudaFunction> functions;124125private final Map<String, CudaGlobal> globals;126127private final AtomicLong nativeHandle;128129private final Map<String, CudaSurface> surfaces;130131private final Map<String, CudaTexture> textures;132133/**134* Loads a module on the specified device, using the given image and the135* default options.136*137* @param device138* the specified device139* @param image140* the module image141* @throws CudaException142* if a CUDA exception occurs143* @throws SecurityException144* if a security manager exists and the calling thread145* does not have permission to load GPU modules146*/147public CudaModule(CudaDevice device, byte[] image) throws CudaException {148this(device, image, null);149}150151/**152* Loads a module on the specified device, using the given image and the153* given options.154*155* @param device156* the specified device157* @param image158* the module image159* @param options160* the desired options161* @throws CudaException162* if a CUDA exception occurs163* @throws SecurityException164* if a security manager exists and the calling thread165* does not have permission to load GPU modules166*/167public CudaModule(CudaDevice device, byte[] image, CudaJitOptions options)168throws CudaException {169super();170171@SuppressWarnings("removal")172SecurityManager security = System.getSecurityManager();173174if (security != null) {175security.checkPermission(CudaPermission.LoadModule);176}177178if (image == null) {179throw new NullPointerException();180}181182this.deviceId = device.getDeviceId();183184long optionsHandle = options == null ? 0 : options.getHandle();185186try {187this.functions = new HashMap<>();188this.globals = new HashMap<>();189this.nativeHandle = new AtomicLong( // <br/>190load(this.deviceId, image, optionsHandle));191this.surfaces = new HashMap<>();192this.textures = new HashMap<>();193} finally {194if (options != null) {195options.releaseHandle(true);196}197}198}199200/**201* Loads a module on the specified device from the given input stream using202* the default options.203*204* @param device205* the specified device206* @param input207* a stream containing the module image208* @throws CudaException209* if a CUDA exception occurs210* @throws IOException211* if an I/O error occurs reading {@code input}212* @throws SecurityException213* if a security manager exists and the calling thread214* does not have permission to load GPU modules215*/216public CudaModule(CudaDevice device, InputStream input)217throws CudaException, IOException {218this(device, input, null);219}220221/**222* Loads a module on the specified device from the given input stream using223* the specified options.224*225* @param device226* the specified device227* @param input228* a stream containing the module image229* @param options230* the desired options231* @throws CudaException232* if a CUDA exception occurs233* @throws IOException234* if an I/O error occurs reading {@code input}235* @throws SecurityException236* if a security manager exists and the calling thread237* does not have permission to load GPU modules238*/239public CudaModule(CudaDevice device, InputStream input,240CudaJitOptions options) throws CudaException, IOException {241this(device, CudaUtil.read(input, true), options);242}243244/**245* Returns the function of the specified name from this module.246*247* @param name248* the link-name of the desired function249* @return250* the function of the specified name251* @throws CudaException252* if a CUDA exception occurs253* @throws IllegalStateException254* if this module has been unloaded (see {@link #unload()})255*/256public CudaFunction getFunction(String name) throws CudaException {257CudaFunction function = functions.get(name);258259if (function == null) {260long address = getFunction(deviceId, getHandle(), name);261262functions.put(name, function = new CudaFunction(deviceId, address));263}264265return function;266}267268/**269* Returns the global variable of the specified name from this module.270*271* @param name272* the link-name of the desired global variable273* @return274* the global variable of the specified name275* @throws CudaException276* if a CUDA exception occurs277* @throws IllegalStateException278* if this module has been unloaded (see {@link #unload()})279*/280public CudaGlobal getGlobal(String name) throws CudaException {281CudaGlobal global = globals.get(name);282283if (global == null) {284long address = getGlobal(deviceId, getHandle(), name);285286globals.put(name, global = new CudaGlobal(address));287}288289return global;290}291292private long getHandle() {293long handle = nativeHandle.get();294295if (handle == 0) {296throw new IllegalStateException();297}298299return handle;300}301302/**303* Returns the surface of the specified name from this module.304*305* @param name306* the link-name of the desired surface307* @return308* the surface of the specified name309* @throws CudaException310* if a CUDA exception occurs311* @throws IllegalStateException312* if this module has been unloaded (see {@link #unload()})313*/314public CudaSurface getSurface(String name) throws CudaException {315CudaSurface surface = surfaces.get(name);316317if (surface == null) {318long address = getSurface(deviceId, getHandle(), name);319320surfaces.put(name, surface = new CudaSurface(address));321}322323return surface;324}325326/**327* Returns the texture of the specified name from this module.328*329* @param name330* the link-name of the desired texture331* @return332* the texture of the specified name333* @throws CudaException334* if a CUDA exception occurs335* @throws IllegalStateException336* if this module has been unloaded (see {@link #unload()})337*/338public CudaTexture getTexture(String name) throws CudaException {339CudaTexture texture = textures.get(name);340341if (texture == null) {342long address = getTexture(deviceId, getHandle(), name);343344textures.put(name, texture = new CudaTexture(address));345}346347return texture;348}349350/**351* Unloads this module from the associated device.352* <p>353* Note that this has no effect on any {@link Cache caches}.354* @throws CudaException355* if a CUDA exception occurs356*/357public void unload() throws CudaException {358long handle = nativeHandle.getAndSet(0);359360if (handle != 0) {361functions.clear();362globals.clear();363surfaces.clear();364textures.clear();365unload(deviceId, handle);366}367}368}369370371