Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
wiseplat
GitHub Repository: wiseplat/python-code
Path: blob/master/ invest-robot-contest_tinkoffSDK-master/my_strategy.py
5925 views
1
import logging
2
from datetime import datetime, timedelta
3
from decimal import Decimal
4
from typing import Callable, Iterable, List
5
from typing_extensions import Self
6
import uuid
7
8
import numpy as np
9
10
from tinkoff.invest.strategies.base.account_manager import AccountManager
11
from tinkoff.invest.strategies.base.errors import (
12
CandleEventForDateNotFound,
13
NotEnoughData,
14
OldCandleObservingError,
15
)
16
from tinkoff.invest.strategies.base.models import CandleEvent
17
from tinkoff.invest.strategies.base.signal import (
18
CloseLongMarketOrder,
19
CloseShortMarketOrder,
20
OpenLongMarketOrder,
21
OpenShortMarketOrder,
22
Signal,
23
)
24
from tinkoff.invest.strategies.base.strategy_interface import InvestStrategy
25
from tinkoff.invest.strategies.moving_average.strategy_settings import (
26
MovingAverageStrategySettings,
27
)
28
from tinkoff.invest.strategies.moving_average.strategy_state import (
29
MovingAverageStrategyState,
30
)
31
from tinkoff.invest.utils import (
32
candle_interval_to_timedelta,
33
ceil_datetime,
34
floor_datetime,
35
now,
36
)
37
38
# Module-level logger named after this module; used by every strategy method.
logger: logging.Logger = logging.getLogger(name=__name__)
39
40
41
class MovingAverageStrategy(InvestStrategy):
    """Trend-following strategy based on a short/long moving-average crossover.

    Observed candle events are kept in ``self._data`` (oldest first). A newer
    event that falls into the same candle interval as the last stored one
    replaces it; otherwise it is appended.

    On each ``predict`` call the strategy compares the moving average of close
    prices over ``settings.short_period`` with the one over
    ``settings.long_period``:

    * ``MA_SHORT > MA_LONG`` -> open a long position;
    * ``MA_SHORT < MA_LONG`` while a long is open -> close the long position.

    Opening short positions is disabled in ``predict``.

    NOTE(review): price windows are measured back from ``now()`` (wall clock),
    not from the newest candle's timestamp. When replaying historical candles
    the windows may therefore be empty — confirm this is intended.
    """

    def __init__(
        self,
        settings: MovingAverageStrategySettings,
        account_manager: AccountManager,
        state: MovingAverageStrategyState,
    ):
        """Create the strategy.

        Args:
            settings: Strategy parameters (periods, candle interval, ...).
            account_manager: Used in ``predict`` to read the current balance.
            state: Position state; ``long_open``, ``short_open`` and
                ``position`` are read when deciding to close positions.
        """
        # Observed candle events, oldest first (see ``observe``).
        self._data: List[CandleEvent] = []
        self._settings = settings
        self._account_manager = account_manager

        self._state = state
        # Close price of the newest candle at or before now() - short_period;
        # assigned by ``_init_MA_LONG_START`` on every ``predict`` call.
        self._MA_LONG_START: Decimal
        self._candle_interval_timedelta = candle_interval_to_timedelta(
            self._settings.candle_interval
        )

    def _ensure_enough_candles(self) -> None:
        """Log whether enough candles are stored to cover both MA periods.

        The number of candles needed is ``(short_period + long_period)``
        divided by the candle interval. Raising ``NotEnoughData`` is
        intentionally disabled here; a shortfall is only logged.
        """
        # Use the interval cached in __init__, like _append_candle_event does.
        candles_needed = (
            self._settings.short_period + self._settings.long_period
        ) / self._candle_interval_timedelta
        if candles_needed > len(self._data):
            logger.info("NotEnoughData for strategy")
        else:
            logger.info("Got enough data for strategy")

    def fit(self, candles: Iterable[CandleEvent]) -> None:
        """Feed historical candles into the strategy, then check data volume.

        Args:
            candles: Candle events, expected in chronological order
                (``observe`` raises ``OldCandleObservingError`` otherwise).
        """
        logger.debug("Strategy fitting with candles %s", candles)
        for candle in candles:
            self.observe(candle)
        self._ensure_enough_candles()

    def _append_candle_event(self, candle_event: CandleEvent) -> None:
        """Merge ``candle_event`` into the non-empty ``self._data`` list.

        If the event falls inside the last stored event's candle interval
        ``[floor, ceil)`` (or exactly on its floor), it replaces that event,
        because it is a more recent snapshot of the same candle. Later events
        are appended.

        Raises:
            OldCandleObservingError: if the event is older than the start of
                the last stored event's interval.
        """
        last_candle_event = self._data[-1]
        last_interval_floor = floor_datetime(
            last_candle_event.time, self._candle_interval_timedelta
        )
        last_interval_ceil = ceil_datetime(
            last_candle_event.time, self._candle_interval_timedelta
        )

        if candle_event.time < last_interval_floor:
            raise OldCandleObservingError()
        # The equality check covers the case where the last time is already
        # aligned to the interval, so floor == ceil.
        if (
            candle_event.time < last_interval_ceil
            or candle_event.time == last_interval_floor
        ):
            self._data[-1] = candle_event
        else:
            self._data.append(candle_event)

    def observe(self, candle: CandleEvent) -> None:
        """Record a new candle event (see ``_append_candle_event``)."""
        logger.debug("Observing candle event: %s", candle)

        if self._data:
            self._append_candle_event(candle)
        else:
            self._data.append(candle)

    @staticmethod
    def _get_newer_than_datetime_predicate(
        anchor: datetime,
    ) -> Callable[[CandleEvent], bool]:
        """Return a predicate that is true for events strictly after ``anchor``."""

        def _(event: CandleEvent) -> bool:
            return event.time > anchor

        return _

    def _filter_from_the_end_with_early_stop(
        self, predicate: Callable[[CandleEvent], bool]
    ) -> Iterable[CandleEvent]:
        """Yield stored events newest-first while ``predicate`` holds.

        Iteration stops at the first event that fails the predicate. This
        relies on ``self._data`` being in chronological order.
        """
        for event in reversed(self._data):
            if not predicate(event):
                break
            yield event

    def _select_for_period(self, period: timedelta) -> Iterable[CandleEvent]:
        """Lazily select events newer than ``now() - period`` (newest first)."""
        predicate = self._get_newer_than_datetime_predicate(now() - period)
        return self._filter_from_the_end_with_early_stop(predicate)

    @staticmethod
    def _get_prices(events: Iterable[CandleEvent]) -> Iterable[Decimal]:
        """Yield the close price of each event."""
        for event in events:
            yield event.candle.close

    def _calculate_moving_average(self, period: timedelta) -> Decimal:
        """Return the mean close price over the last ``period``."""
        prices = list(self._get_prices(self._select_for_period(period)))
        # total_seconds() rather than .seconds: .seconds drops whole days.
        logger.info(
            "Moving average over %s seconds. Selected prices: %s",
            period.total_seconds(),
            prices,
        )
        return np.mean(prices, axis=0)  # type: ignore

    def _calculate_std(self, period: timedelta) -> Decimal:
        """Return the population standard deviation of close prices over ``period``."""
        prices = list(self._get_prices(self._select_for_period(period)))
        return np.std(prices, axis=0)  # type: ignore

    def _get_first_candle_before(self, date: datetime) -> CandleEvent:
        """Return the newest stored event whose time is at or before ``date``.

        Raises:
            CandleEventForDateNotFound: if every stored event is newer than
                ``date``, or nothing is stored.
        """
        predicate = self._get_newer_than_datetime_predicate(date)
        for event in reversed(self._data):
            if not predicate(event):
                return event
        raise CandleEventForDateNotFound()

    def _init_MA_LONG_START(self):
        """Set ``_MA_LONG_START`` to the close price at ``now() - short_period``."""
        date = now() - self._settings.short_period
        event = self._get_first_candle_before(date)
        self._MA_LONG_START = event.candle.close

    @staticmethod
    def _is_long_open_signal(
        MA_SHORT: Decimal,
        MA_LONG: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        MA_LONG_START: Decimal,
    ) -> bool:
        """Return True on an up-trend (short MA above long MA).

        ``PRICE``, ``STD`` and ``MA_LONG_START`` are accepted for signature
        symmetry with the other signal checks but are not used.
        """
        logger.debug("Try long opening")
        logger.debug("\tMA_SHORT > MA_LONG, %s", MA_SHORT > MA_LONG)
        # If the trend is rising, buy.
        return MA_SHORT > MA_LONG

    @staticmethod
    def _is_short_open_signal(
        MA_SHORT: Decimal,
        MA_LONG: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        MA_LONG_START: Decimal,
    ) -> bool:
        """Return True when a short position should be opened.

        All of the following must hold:
        ``MA_SHORT < MA_LONG < MA_LONG_START`` and
        ``abs((PRICE - MA_LONG) / MA_LONG) < STD``.
        """
        logger.debug("Try short opening")
        short_below_long = MA_SHORT < MA_LONG
        price_near_long = abs((PRICE - MA_LONG) / MA_LONG) < STD
        long_below_start = MA_LONG < MA_LONG_START
        logger.debug("\tMA_SHORT < MA_LONG, %s", short_below_long)
        logger.debug(
            "\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
            price_near_long,
        )
        # The message now matches the comparison actually performed.
        logger.debug("\tand MA_LONG < MA_LONG_START, %s", long_below_start)
        result = short_below_long and long_below_start and price_near_long
        logger.debug("== %s", result)
        return result

    @staticmethod
    def _is_long_close_signal(
        MA_LONG: Decimal,
        MA_SHORT: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        has_short_open_signal: bool,
    ) -> bool:
        """Return True on a down-trend (short MA below long MA).

        ``PRICE``, ``STD`` and ``has_short_open_signal`` are accepted for
        signature compatibility but are not used.
        """
        logger.debug("Try long closing")
        logger.debug("MA_SHORT < MA_LONG %s", MA_SHORT < MA_LONG)
        # If the trend is falling, close the position.
        return MA_SHORT < MA_LONG

    @staticmethod
    def _is_short_close_signal(
        MA_LONG: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        has_long_open_signal: bool,
    ) -> bool:
        """Return True when an open short position should be closed.

        Any of the following triggers it: the price is more than ``10 * STD``
        below ``MA_LONG``, a long open signal fired this round, or the price is
        more than ``3 * STD`` above ``MA_LONG``.
        """
        logger.debug("Try short closing")
        # NOTE(review): the original author remarked that this closing logic
        # "does not seem to work" — TODO investigate.
        far_below_long = PRICE < MA_LONG - 10 * STD
        far_above_long = PRICE > MA_LONG + 3 * STD
        logger.debug("\tPRICE < MA_LONG - 10 * STD, %s", far_below_long)
        logger.debug("\tor has_long_open_signal, %s", has_long_open_signal)
        logger.debug("\tor PRICE > MA_LONG + 3 * STD, %s", far_above_long)
        result = far_below_long or has_long_open_signal or far_above_long
        logger.debug("== %s", result)
        return result

    def predict(self) -> Iterable[Signal]:  # noqa: C901
        """Yield trading signals based on the currently observed candles.

        * ``OpenLongMarketOrder`` when the long open signal holds. Whether a
          long is already open is intentionally not checked.
        * Short positions are never opened.
        * ``CloseLongMarketOrder`` for the whole position when
          ``self._state.long_open`` is set and the long close signal holds.
        * ``CloseShortMarketOrder`` when ``self._state.short_open`` is set and
          the short close signal holds.

        Raises:
            CandleEventForDateNotFound: if no stored candle is at or before
                ``now() - short_period``. This includes the case where no
                candles have been observed.
        """
        logger.info("Strategy predict")
        self._init_MA_LONG_START()
        MA_LONG_START = self._MA_LONG_START
        logger.debug("MA_LONG_START: %s", MA_LONG_START)
        PRICE = self._data[-1].candle.close
        logger.debug("PRICE: %s", PRICE)
        MA_LONG = self._calculate_moving_average(self._settings.long_period)
        logger.debug("MA_LONG: %s", MA_LONG)
        MA_SHORT = self._calculate_moving_average(self._settings.short_period)
        logger.debug("MA_SHORT: %s", MA_SHORT)
        STD = self._calculate_std(self._settings.std_period)
        logger.debug("STD: %s", STD)
        # Assumes get_current_balance() returns (money, shares); only the money
        # amount is used here — TODO confirm against the AccountManager in use.
        MONEY, _shares = self._account_manager.get_current_balance()
        logger.debug("MONEY: %s", MONEY)

        has_long_open_signal = False
        has_short_open_signal = False

        # Order size is fixed at one lot while the robot is being tested. It
        # could instead be derived from the balance, e.g.
        # int(MONEY // (PRICE * lot_size)) with lot_size = 10000 for VTB shares.
        possible_lots = 1

        # Simplification: do not check self._state.long_open (whether a
        # previous open signal already fired) before opening a long.
        long_opening_enabled = True
        # Short positions are forbidden for this strategy.
        short_opening_enabled = False

        if (
            long_opening_enabled
            and self._is_long_open_signal(
                MA_SHORT=MA_SHORT,
                MA_LONG=MA_LONG,
                PRICE=PRICE,
                STD=STD,
                MA_LONG_START=MA_LONG_START,
            )
            and possible_lots > 0
        ):
            has_long_open_signal = True
            yield OpenLongMarketOrder(lots=possible_lots)

        if (
            short_opening_enabled
            and self._is_short_open_signal(
                MA_SHORT=MA_SHORT,
                MA_LONG=MA_LONG,
                PRICE=PRICE,
                STD=STD,
                MA_LONG_START=MA_LONG_START,
            )
            and possible_lots > 0
        ):
            has_short_open_signal = True
            yield OpenShortMarketOrder(lots=possible_lots)

        # TODO: long positions should be closed at the end of the trading day;
        # this is not implemented.
        if self._state.long_open and self._is_long_close_signal(
            MA_LONG=MA_LONG,
            MA_SHORT=MA_SHORT,
            PRICE=PRICE,
            STD=STD,
            has_short_open_signal=has_short_open_signal,
        ):
            yield CloseLongMarketOrder(lots=self._state.position)

        # Shorts are never opened above, so this only matters if
        # self._state reports an open short position.
        if self._state.short_open and self._is_short_close_signal(
            MA_LONG=MA_LONG,
            PRICE=PRICE,
            STD=STD,
            has_long_open_signal=has_long_open_signal,
        ):
            yield CloseShortMarketOrder(lots=self._state.position)
301
302