Path: blob/master/finrl/meta/env_cryptocurrency_trading/env_btc_ccxt.py
732 views
from __future__ import annotations

import numpy as np


class BitcoinEnv:  # custom env
    """Single-asset trading environment driven by pre-computed arrays.

    The environment walks row by row through ``price_ary`` / ``tech_ary``
    (one row per step).  At every step the agent emits one continuous
    action: a positive value buys that many units, a negative value sells
    that many units, both at ``price_ary[day][0]`` and both charged
    ``transaction_fee_percent``.

    The observation is a flat ``float32`` vector made of, in order: the
    scaled cash balance, the scaled price row, the first seven technical
    indicators (each scaled by a fixed power of two) and the scaled position.
    """

    # Fixed scale factors applied to the first seven columns of a tech row.
    _TECH_SCALES = (2**-1, 2**-15, 2**-15, 2**-6, 2**-6, 2**-15, 2**-15)

    def __init__(
        self,
        data_cwd=None,
        price_ary=None,
        tech_ary=None,
        time_frequency=15,
        start=None,
        mid1=172197,
        mid2=216837,
        end=None,
        initial_account=1e6,
        max_stock=1e2,
        transaction_fee_percent=1e-3,
        mode="train",
        gamma=0.99,
    ):
        """Load the data window for ``mode`` and initialise the portfolio.

        Args:
            data_cwd: Directory holding ``price_ary.npy`` and ``tech_ary.npy``.
                When given, it takes precedence over the array arguments.
            price_ary: 2-D price array (rows are time steps); used when
                ``data_cwd`` is None.
            tech_ary: 2-D technical-indicator array aligned with ``price_ary``.
            time_frequency: Keep every ``time_frequency``-th row of the window.
            start, mid1, mid2, end: Row indices delimiting the windows
                ``train = [start:mid1]``, ``test = [mid1:mid2]`` and
                ``trade = [mid2:end]``.
            initial_account: Cash balance at the start of every episode.
            max_stock: Accepted for interface compatibility (see note below).
            transaction_fee_percent: Proportional fee charged on each trade.
            mode: One of ``"train"``, ``"test"`` or ``"trade"``.
            gamma: Discount factor used for the terminal bonus in ``step``.

        Raises:
            ValueError: If the data cannot be loaded or ``mode`` is invalid.
        """
        self.stock_dim = 1
        self.initial_account = initial_account
        self.transaction_fee_percent = transaction_fee_percent
        # NOTE(review): the `max_stock` argument is ignored and the attribute
        # is pinned to 1; nothing in this class reads it.  Kept unchanged so
        # external readers of `env.max_stock` see the same value as before.
        self.max_stock = 1
        self.gamma = gamma
        self.mode = mode
        self.load_data(
            data_cwd, price_ary, tech_ary, time_frequency, start, mid1, mid2, end
        )

        # reset
        self._reset_portfolio()
        self.episode_return = 0.0

        """env information"""
        self.env_name = "BitcoinEnv4"
        self.state_dim = 1 + 1 + self.price_ary.shape[1] + self.tech_ary.shape[1]
        self.action_dim = 1
        self.if_discrete = False
        self.target_return = 10
        self.max_step = self.price_ary.shape[0]

    def _reset_portfolio(self) -> None:
        """Rewind to the first row and restore cash, position and accumulators."""
        self.day = 0
        self.day_price = self.price_ary[self.day]
        self.day_tech = self.tech_ary[self.day]
        self.initial_account__reset = self.initial_account
        self.account = self.initial_account__reset
        self.stocks = 0.0  # multi-stack
        self.total_asset = self.account + self.day_price[0] * self.stocks
        # Cleared here so that a reset in the middle of an episode does not
        # leak the previous discounted return into the next terminal reward.
        self.gamma_return = 0.0

    def _get_state(self) -> np.ndarray:
        """Build the scaled ``float32`` observation for the current row."""
        normalized_tech = [
            self.day_tech[col] * scale for col, scale in enumerate(self._TECH_SCALES)
        ]
        return np.hstack(
            (
                self.account * 2**-18,
                self.day_price * 2**-15,
                normalized_tech,
                self.stocks * 2**-4,
            )
        ).astype(np.float32)

    def reset(
        self,
        *,
        seed=None,
        options=None,
    ) -> np.ndarray:
        """Start a new episode and return the initial observation.

        ``seed`` and ``options`` are accepted for API compatibility and are
        not used.
        """
        self._reset_portfolio()
        return self._get_state()

    def step(self, action) -> tuple[np.ndarray, float, bool, None]:
        """Execute one trade, advance one row and return the transition.

        Args:
            action: Indexable whose first element is the signed trade size
                (negative sells, positive buys, zero holds).

        Returns:
            ``(state, reward, done, None)``.  ``reward`` is the change in
            total asset value scaled by ``2**-16``; on the final step the
            accumulated discounted return is added to it.
        """
        stock_action = action[0]
        price = self.day_price[0]
        fee = self.transaction_fee_percent

        """buy or sell stock"""
        if stock_action < 0:
            # Selling may go short, but by at most half of the current total
            # asset value expressed in units, on top of the held position.
            sell_limit = 0.5 * self.total_asset / price + self.stocks
            sold = max(0, min(-1 * stock_action, sell_limit))
            self.account += price * sold * (1 - fee)
            self.stocks -= sold
        elif stock_action > 0:
            # The affordable amount has to include the fee; otherwise an
            # all-in buy overdraws the cash balance, and a negative balance
            # would turn the next "buy" into an unintended sale.
            affordable = max(0.0, self.account / (price * (1 + fee)))
            bought = min(stock_action, affordable)
            self.account -= price * bought * (1 + fee)
            self.stocks += bought

        """update day"""
        self.day += 1
        self.day_price = self.price_ary[self.day]
        self.day_tech = self.tech_ary[self.day]
        done = (self.day + 1) == self.max_step
        state = self._get_state()

        next_total_asset = self.account + self.day_price[0] * self.stocks
        reward = (next_total_asset - self.total_asset) * 2**-16
        self.total_asset = next_total_asset

        self.gamma_return = self.gamma_return * self.gamma + reward
        if done:
            reward += self.gamma_return
            self.gamma_return = 0.0
            self.episode_return = next_total_asset / self.initial_account
        return state, reward, done, None

    def draw_cumulative_return(self, args, _torch) -> tuple[list, list]:
        """Roll out a saved agent for one episode and plot its return curve.

        Args:
            args: Object exposing ``agent``, ``net_dim`` and ``cwd``.
            _torch: The ``torch`` module (passed in by the caller).

        Returns:
            ``(episode_returns, btc_returns)``: the agent's total asset and
            the price of the asset, each as a multiple of its starting value.
            The plot is saved to ``{cwd}/cumulative_return.jpg``.
        """
        state_dim = self.state_dim
        action_dim = self.action_dim

        agent = args.agent
        net_dim = args.net_dim
        cwd = args.cwd

        agent.init(net_dim, state_dim, action_dim)
        agent.save_load_model(cwd=cwd, if_save=False)
        act = agent.act
        device = agent.device

        state = self.reset()
        episode_returns = [1]
        btc_returns = []  # the cumulative_return / initial_account
        init_price = self.day_price[0]
        with _torch.no_grad():
            for _ in range(self.max_step):
                # `day_price` is the current row, so the traded price is
                # always column 0 (not the step counter).
                btc_returns.append(self.day_price[0] / init_price)
                s_tensor = _torch.as_tensor((state,), device=device)
                a_tensor = act(s_tensor)  # action_tanh = act.forward()
                action = a_tensor.detach().cpu().numpy()[0]
                state, reward, done, _ = self.step(action)

                # Normalise by the configured starting cash so the curve
                # starts at 1 for any `initial_account`.
                episode_returns.append(self.total_asset / self.initial_account)
                if done:
                    break

        import matplotlib.pyplot as plt

        plt.plot(episode_returns, label="agent return")
        plt.plot(btc_returns, color="yellow", label="BTC return")
        plt.grid()
        plt.title("cumulative return")
        plt.xlabel("day")
        plt.ylabel("multiple of initial_account")
        plt.legend()
        plt.savefig(f"{cwd}/cumulative_return.jpg")
        return episode_returns, btc_returns

    def load_data(
        self, data_cwd, price_ary, tech_ary, time_frequency, start, mid1, mid2, end
    ):
        """Select and subsample the data window for ``self.mode``.

        Sets ``self.price_ary`` and ``self.tech_ary`` to every
        ``time_frequency``-th row of the window belonging to the mode.

        Raises:
            ValueError: If the ``.npy`` files cannot be loaded, if no data
                source is given, or if ``self.mode`` is not a known mode.
        """
        if data_cwd is not None:
            try:
                price_ary = np.load(f"{data_cwd}/price_ary.npy")
                tech_ary = np.load(f"{data_cwd}/tech_ary.npy")
            except Exception as err:
                raise ValueError("Data files not found!") from err
        elif price_ary is None or tech_ary is None:
            raise ValueError(
                "Either data_cwd or both price_ary and tech_ary must be provided!"
            )

        windows = {
            "train": (start, mid1),
            "test": (mid1, mid2),
            "trade": (mid2, end),
        }
        if self.mode not in windows:
            raise ValueError("Invalid Mode!")

        lo, hi = windows[self.mode]
        price_window = price_ary[lo:hi]
        tech_window = tech_ary[lo:hi]

        # The row count of the price window decides how many samples are
        # kept; the same row indices are applied to the tech window.
        stride = int(time_frequency)
        keep = np.arange(price_window.shape[0] // stride) * stride
        self.price_ary = price_window[keep]
        self.tech_ary = tech_window[keep]