Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
wiseplat
GitHub Repository: wiseplat/python-code
Path: blob/master/ invest-robot-contest_investRobot-master/robotlib/strategy.py
5927 views
1
import datetime
2
import math
3
import random
4
5
from abc import ABC, abstractmethod
6
from dataclasses import dataclass, field
7
8
from tinkoff.invest import (
9
Candle,
10
HistoricCandle,
11
Instrument,
12
MarketDataResponse,
13
OrderType,
14
OrderDirection,
15
OrderState,
16
Quotation,
17
SubscriptionInterval,
18
)
19
from robotlib.money import Money
20
from robotlib.vizualization import Visualizer
21
22
23
@dataclass
class TradeStrategyParams:
    """
    Account state handed to a strategy together with each market data update.
    """

    # Amount of the traded instrument currently held; strategies use it to cap
    # sell quantities. NOTE(review): presumably counted in lots - confirm.
    instrument_balance: int

    # Funds available for purchases, in the same units as the candle prices.
    currency_balance: float

    # Orders that are still open. NOTE(review): presumably placed but not yet
    # executed - verify against the robot that fills this field.
    pending_orders: list[OrderState]
28
29
30
@dataclass
class RobotTradeOrder:
    """
    Description of a single order a strategy asks the robot to place.
    """

    # How much to trade.
    quantity: int

    # Whether the order buys or sells.
    direction: OrderDirection

    # Optional order price; left as None by default.
    price: Money | None = None

    # Orders are market orders unless stated otherwise.
    order_type: OrderType = OrderType.ORDER_TYPE_MARKET
36
37
38
@dataclass
class StrategyDecision:
    """
    Result of one strategy step: an optional new order plus orders to cancel.
    """

    # Order to place, or None when the strategy does not want to trade.
    robot_trade_order: RobotTradeOrder | None = None

    # Orders the strategy wants cancelled; every decision gets its own list.
    cancel_orders: list[OrderState] = field(default_factory=list)
42
43
44
class TradeStrategyBase(ABC):
    """
    Abstract interface every trading strategy has to implement.

    All abstract members must be overridden by a concrete strategy, even
    the ones that carry a default body here.
    """

    # Instrument description stored by load_instrument_info().
    instrument_info: Instrument

    @property
    @abstractmethod
    def candle_subscription_interval(self) -> SubscriptionInterval:
        """
        Candle interval the strategy works with (one minute in this default body).
        """
        return SubscriptionInterval.SUBSCRIPTION_INTERVAL_ONE_MINUTE

    @property
    @abstractmethod
    def order_book_subscription_depth(self) -> int | None:
        """
        Order book depth; set not None to subscribe robot to order book.
        """
        return None

    @property
    @abstractmethod
    def trades_subscription(self) -> bool:
        """
        Set True to subscribe robot to trades stream.
        """
        return False

    @property
    @abstractmethod
    def strategy_id(self) -> str:
        """
        String representing short strategy name for logger.
        """
        raise NotImplementedError()

    def load_instrument_info(self, instrument_info: Instrument):
        """
        Remember the instrument description for later use by the strategy.
        """
        self.instrument_info = instrument_info

    def load_candles(self, candles: list[HistoricCandle]) -> None:
        """
        Method used by robot to load historic data. Does nothing by default.
        """

    @abstractmethod
    def decide(self, market_data: MarketDataResponse, params: TradeStrategyParams) -> StrategyDecision:
        """
        Produce a decision for a market data update.

        This default body forwards a candle update to decide_by_candle() and
        answers any other update with an empty decision.
        """
        candle = market_data.candle
        if not candle:
            return StrategyDecision()
        return self.decide_by_candle(candle, params)

    @abstractmethod
    def decide_by_candle(self, candle: Candle | HistoricCandle, params: TradeStrategyParams) -> StrategyDecision:
        """
        Produce a decision for a single (live or historic) candle.
        """
88
89
90
class RandomStrategy(TradeStrategyBase):
    """
    Strategy that trades a random quantity on every candle.

    The quantity is drawn uniformly from [low, high] after clipping the range
    to what the account allows: no more can be sold than is held, and no more
    can be bought than the currency balance pays for at the candle close
    price. A positive draw becomes a buy order, a negative draw a sell order.
    """

    request_candles: bool = True
    strategy_id: str = 'random'

    # Concrete values for the abstract properties of TradeStrategyBase.
    # Without them the class stays abstract and cannot be instantiated.
    candle_subscription_interval: SubscriptionInterval = SubscriptionInterval.SUBSCRIPTION_INTERVAL_ONE_MINUTE
    order_book_subscription_depth: int | None = None
    trades_subscription: bool = False

    low: int
    high: int

    def __init__(self, low: int, high: int):
        """
        :param low: lower bound of the random quantity (negative values sell)
        :param high: upper bound of the random quantity (positive values buy)
        """
        self.low = low
        self.high = high

    def decide(self, market_data: MarketDataResponse, params: TradeStrategyParams) -> StrategyDecision:
        """
        Decide by the candle of the update; updates without a candle yield an empty decision.
        """
        if not market_data.candle:
            return StrategyDecision()
        return self.decide_by_candle(market_data.candle, params)

    def decide_by_candle(self, candle: Candle | HistoricCandle, params: TradeStrategyParams) -> StrategyDecision:
        """
        Draw a random quantity within the account limits and turn it into a market order.

        Returns an empty decision when the price is unusable, when the clipped
        range is empty, or when the draw is zero.
        """
        price = self.convert_quotation(candle.close)
        if price is None or price <= 0:
            # no meaningful price: the affordable quantity cannot be computed
            return StrategyDecision()

        # NOTE(review): the affordable quantity ignores the instrument lot size,
        # unlike MAEStrategy - confirm the unit of `quantity` expected by the robot.
        lowest = max(self.low, -params.instrument_balance)
        highest = min(self.high, math.floor(params.currency_balance / price))
        if lowest > highest:
            # random.randint raises ValueError on an empty range
            return StrategyDecision()

        quantity = random.randint(lowest, highest)
        if quantity == 0:
            return StrategyDecision()

        if quantity > 0:
            direction = OrderDirection.ORDER_DIRECTION_BUY
        else:
            direction = OrderDirection.ORDER_DIRECTION_SELL

        # the direction carries the sign, so the order quantity itself is positive
        return StrategyDecision(RobotTradeOrder(quantity=abs(quantity), direction=direction))

    @staticmethod
    def convert_quotation(amount: Quotation) -> float | None:
        """
        Convert a units/nano quotation to float; None is passed through.
        """
        if amount is None:
            return None
        return amount.units + amount.nano / (10 ** 9)
118
119
120
class MAEStrategy(TradeStrategyBase):
    """
    Moving average crossover strategy on minute candles.

    Close prices are stored per minute. When a candle for a new minute
    arrives, the long and the short averages of the stored prices are
    compared. If the comparison flipped since the previous check, the
    strategy sells (long average above the short one) or buys (otherwise)
    up to `trade_count` units, limited by the account balances.
    """

    request_candles: bool = True
    strategy_id: str = 'mae'

    candle_subscription_interval: SubscriptionInterval = SubscriptionInterval.SUBSCRIPTION_INTERVAL_ONE_MINUTE
    order_book_subscription_depth: int | None = None
    trades_subscription: bool = False

    short_len: int
    long_len: int
    trade_count: int
    prices: dict[datetime.datetime, Money]
    prev_sign: bool

    def __init__(self, short_len: int = 5, long_len: int = 20, trade_count: int = 1,
                 visualizer: Visualizer | None = None):
        """
        :param short_len: number of latest prices in the short average
        :param long_len: number of latest prices in the long average
        :param trade_count: maximal quantity of a single order
        :param visualizer: optional plot receiving prices and buy/sell marks
        :raises ValueError: if the window lengths are not 0 < short_len < long_len
        """
        # explicit check instead of assert, which is stripped under `python -O`
        if short_len <= 0 or long_len <= short_len:
            raise ValueError('window lengths must satisfy 0 < short_len < long_len')
        self.short_len = short_len
        self.long_len = long_len
        self.trade_count = trade_count
        self.prices = {}
        # same value load_candles() computes for an empty history, so the
        # strategy is usable even when no historic candles were loaded
        self.prev_sign = False
        self.visualizer = visualizer

    @staticmethod
    def _minute(moment: datetime.datetime) -> datetime.datetime:
        """
        Truncate a timestamp to the start of its minute.
        """
        return moment.replace(second=0, microsecond=0)

    def load_candles(self, candles: list[HistoricCandle]) -> None:
        """
        Replace the price history with the last `long_len` historic candles
        and remember the current relation of the averages.
        """
        self.prices = {self._minute(candle.time): Money(candle.close)
                       for candle in candles[-self.long_len:]}
        self.prev_sign = self._long_avg() > self._short_avg()

    def decide(self, market_data: MarketDataResponse, params: TradeStrategyParams) -> StrategyDecision:
        """
        Decide by the candle of the update; updates without a candle yield an empty decision.
        """
        if not market_data.candle:
            return StrategyDecision()
        return self.decide_by_candle(market_data.candle, params)

    def decide_by_candle(self, candle: Candle | HistoricCandle, params: TradeStrategyParams) -> StrategyDecision:
        """
        Store the candle close price and, at most once per minute, check for a crossover.
        """
        time: datetime.datetime = self._minute(candle.time)
        close = Money(candle.close)
        order: RobotTradeOrder | None = None

        # make order only once a minute: the first candle of a new minute means
        # the stored prices of the previous minutes are final
        if time not in self.prices:
            sign = self._long_avg() > self._short_avg()
            if sign != self.prev_sign:
                if sign:
                    order = self._sell_order(params)
                    if order is not None and self.visualizer:
                        self.visualizer.add_sell(time)
                else:
                    order = self._buy_order(close, params)
                    if order is not None and self.visualizer:
                        self.visualizer.add_buy(time)
            self.prev_sign = sign

        self.prices[time] = close
        if self.visualizer:
            self.visualizer.add_price(time, close.to_float())
            self.visualizer.update_plot()

        return StrategyDecision(robot_trade_order=order)

    def _sell_order(self, params: TradeStrategyParams) -> RobotTradeOrder | None:
        """
        Build a sell order limited by the held amount, or None when nothing is held.
        """
        if params.instrument_balance <= 0:
            return None
        return RobotTradeOrder(quantity=min(self.trade_count, params.instrument_balance),
                               direction=OrderDirection.ORDER_DIRECTION_SELL)

    def _buy_order(self, close: Money, params: TradeStrategyParams) -> RobotTradeOrder | None:
        """
        Build a buy order limited by the currency balance, or None when not even one lot is affordable.
        """
        lot_price = close.to_float() * self.instrument_info.lot
        # a non-positive lot price would break the division below
        if lot_price <= 0 or params.currency_balance < lot_price:
            return None
        lots_available = int(params.currency_balance / lot_price)
        return RobotTradeOrder(quantity=min(self.trade_count, lots_available),
                               direction=OrderDirection.ORDER_DIRECTION_BUY)

    def get_prices_list(self) -> list[Money]:
        """
        Return the stored prices ordered by their timestamps.
        """
        return [self.prices[moment] for moment in sorted(self.prices)]

    def _average_of_last(self, window: int) -> float:
        """
        Average of the latest `window` prices; 0.0 when no price is stored.

        The sum is divided by the number of prices actually available, so a
        history shorter than the window does not drag the average down.
        """
        recent = self.get_prices_list()[-window:]
        if not recent:
            return 0.0
        return sum(float(price) for price in recent) / len(recent)

    def _long_avg(self):
        """
        Average of the latest `long_len` prices.
        """
        return self._average_of_last(self.long_len)

    def _short_avg(self):
        """
        Average of the latest `short_len` prices.
        """
        return self._average_of_last(self.short_len)
188
189