Path: blob/master/finrl/meta/env_cryptocurrency_trading/env_multiple_crypto.py
732 views
from __future__ import annotations

import numpy as np


class CryptoEnv:  # custom env
    """Multi-asset cryptocurrency trading environment over pre-computed data.

    ``config["price_array"]`` is indexed as ``[time, asset]`` and
    ``config["tech_array"]`` as ``[time, feature]``. Each step the agent submits
    one continuous action per asset. The action is rescaled per asset (see
    ``_generate_action_normalizer``); negative quantities sell and positive
    quantities buy.

    ``reset`` returns only the state (no info dict), and ``step`` returns the
    4-tuple ``(state, reward, done, None)``.
    """

    def __init__(
        self,
        config,
        lookback=1,
        initial_capital=1e6,
        buy_cost_pct=1e-3,
        sell_cost_pct=1e-3,
        gamma=0.99,
    ):
        """Build the environment from ``config`` and put it in its initial state.

        Args:
            config: mapping with ``"price_array"`` (time x asset) and
                ``"tech_array"`` (time x feature) arrays.
            lookback: number of most recent technical-indicator rows included
                in each state.
            initial_capital: starting cash.
            buy_cost_pct: proportional transaction cost charged on buys.
            sell_cost_pct: proportional transaction cost charged on sells.
            gamma: discount factor used to accumulate ``gamma_return``.
        """
        self.lookback = lookback
        self.initial_total_asset = initial_capital
        self.initial_cash = initial_capital
        self.buy_cost_pct = buy_cost_pct
        self.sell_cost_pct = sell_cost_pct
        self.max_stock = 1
        self.gamma = gamma
        self.price_array = config["price_array"]
        self.tech_array = config["tech_array"]
        self._generate_action_normalizer()
        self.crypto_num = self.price_array.shape[1]
        # The episode ends when `time` reaches this index (see step()).
        self.max_step = self.price_array.shape[0] - lookback - 1

        self._reset_episode()
        # Final asset / initial cash of the most recently finished episode.
        self.episode_return = 0.0

        # env information
        self.env_name = "MulticryptoEnv"
        # Must match get_state(): cash (1) + holdings (crypto_num) + `lookback`
        # rows of technical indicators. Holdings are not repeated per lookback row.
        self.state_dim = 1 + self.crypto_num + self.tech_array.shape[1] * lookback
        self.action_dim = self.price_array.shape[1]
        self.if_discrete = False
        self.target_return = 10

    def _reset_episode(self):
        """Rewind to the first usable time index and restore the starting portfolio."""
        # The first `lookback` rows serve as indicator history for get_state().
        self.time = self.lookback - 1
        self.current_price = self.price_array[self.time]
        self.current_tech = self.tech_array[self.time]
        self.cash = self.initial_cash
        self.stocks = np.zeros(self.crypto_num, dtype=np.float32)
        self.total_asset = self.cash + (self.stocks * self.current_price).sum()
        # Cleared every episode so the terminal reward reflects only this episode.
        self.gamma_return = 0.0
        # Holdings are zero, so total asset equals initial cash at the start.
        self.cumu_return = 1.0

    def reset(
        self,
        *,
        seed=None,
        options=None,
    ) -> np.ndarray:
        """Start a new episode and return its first state.

        ``seed`` and ``options`` are accepted for call-site compatibility but are
        ignored: every episode replays the same data from the beginning.
        """
        self._reset_episode()
        return self.get_state()

    def step(self, actions) -> tuple[np.ndarray, float, bool, None]:
        """Advance one time step, executing sells and then buys at the new prices.

        Args:
            actions: one value per asset (length ``action_dim``). Each value is
                multiplied by the asset's ``action_norm_vector`` entry to get a
                quantity; negative quantities sell, positive quantities buy.
                The caller's array is not modified.

        Returns:
            ``(state, reward, done, None)``. ``reward`` is the change in total
            asset value scaled by ``2**-16``. On the final step it is replaced
            by the discounted sum of this episode's rewards (``gamma_return``).
        """
        self.time += 1
        price = self.price_array[self.time]

        # Scale into a fresh float array so the caller's action buffer is untouched.
        actions = np.asarray(actions, dtype=np.float64) * self.action_norm_vector

        for index in np.where(actions < 0)[0]:  # sell_index
            if price[index] > 0:  # Sell only if the price is > 0 (no missing data)
                sell_num_shares = min(self.stocks[index], -actions[index])
                self.stocks[index] -= sell_num_shares
                self.cash += price[index] * sell_num_shares * (1 - self.sell_cost_pct)

        for index in np.where(actions > 0)[0]:  # buy_index
            if price[index] > 0:  # Buy only if the price is > 0 (no missing data)
                # Capped by the whole number of units the remaining cash can pay
                # for, buy cost included (floor division).
                buy_num_shares = min(
                    self.cash // (price[index] * (1 + self.buy_cost_pct)),
                    actions[index],
                )
                self.stocks[index] += buy_num_shares
                self.cash -= price[index] * buy_num_shares * (1 + self.buy_cost_pct)

        done = self.time == self.max_step
        state = self.get_state()
        next_total_asset = self.cash + (self.stocks * price).sum()
        reward = (next_total_asset - self.total_asset) * 2**-16
        self.total_asset = next_total_asset
        self.gamma_return = self.gamma_return * self.gamma + reward
        self.cumu_return = self.total_asset / self.initial_cash
        if done:
            reward = self.gamma_return
            self.episode_return = self.total_asset / self.initial_cash
        return state, reward, done, None

    def get_state(self):
        """Return the current observation as a float32 vector.

        Layout: ``[cash * 2**-18, stocks * 2**-3, tech[t] * 2**-15, ...,
        tech[t - lookback + 1] * 2**-15]``, where ``t`` is the current time
        index (most recent indicators first).
        """
        parts = [self.cash * 2**-18, self.stocks * 2**-3]
        parts.extend(
            self.tech_array[self.time - i] * 2**-15 for i in range(self.lookback)
        )
        return np.hstack(parts).astype(np.float32)

    def close(self):
        """Release resources; nothing to release for this in-memory environment."""
        pass

    def _generate_action_normalizer(self):
        """Derive a per-asset action scale from the first row of prices.

        For each first-row price ``p`` the scale is ``10 ** -(len(str(p)) - 7)``.
        NOTE(review): this depends on the length of the price's string form,
        which reflects its decimal digits as well as its magnitude. It is left
        unchanged here because changing it alters every asset's action scale.
        """
        self.action_norm_vector = np.asarray(
            [1 / (10 ** (len(str(price)) - 7)) for price in self.price_array[0]]
        )