analyze.py 34.0 KB
Newer Older
Z
zengbin93 已提交
1 2
# coding: utf-8

Z
zengbin93 已提交
3
import warnings
4

Z
zengbin93 已提交
5 6 7 8
try:
    import talib as ta
except ImportError:
    from czsc import ta
9

Z
zengbin93 已提交
10 11 12
    ta_lib_hint = "没有安装 ta-lib !!! 请到 https://www.lfd.uci.edu/~gohlke/pythonlibs/#ta-lib " \
                  "下载对应版本安装,预计分析速度提升2倍"
    warnings.warn(ta_lib_hint)
Z
zengbin93 已提交
13
import pandas as pd
14
import numpy as np
Z
zengbin93 已提交
15
from datetime import datetime
Z
zengbin93 已提交
16
from czsc.utils import plot_ka
Z
zengbin93 已提交
17

Z
zengbin93 已提交
18 19
def find_zs(points):
    """输入笔或线段标记点,输出中枢识别结果"""
Z
zengbin93 已提交
20
    if len(points) < 5:
Z
zengbin93 已提交
21 22 23
        return []

    # 当输入为笔的标记点时,新增 xd 值
24
    for j, x in enumerate(points):
Z
zengbin93 已提交
25
        if x.get("bi", 0):
26
            points[j]['xd'] = x["bi"]
Z
zengbin93 已提交
27

28 29 30 31 32 33 34 35 36 37
    def __get_zn(zn_points_):
        """把与中枢方向一致的次级别走势类型称为Z走势段,按中枢中的时间顺序,
        分别记为Zn等,而相应的高、低点分别记为gn、dn"""
        if len(zn_points_) % 2 != 0:
            zn_points_ = zn_points_[:-1]

        if zn_points_[0]['fx_mark'] == "d":
            z_direction = "up"
        else:
            z_direction = "down"
38 39 40 41 42 43 44 45 46 47 48 49

        zn = []
        for i in range(0, len(zn_points_), 2):
            zn_ = {
                "start_dt": zn_points_[i]['dt'],
                "end_dt": zn_points_[i + 1]['dt'],
                "high": max(zn_points_[i]['xd'], zn_points_[i + 1]['xd']),
                "low": min(zn_points_[i]['xd'], zn_points_[i + 1]['xd']),
                "direction": z_direction
            }
            zn_['mid'] = zn_['low'] + (zn_['high'] - zn_['low']) / 2
            zn.append(zn_)
50 51
        return zn

Z
zengbin93 已提交
52 53 54 55 56 57 58 59 60
    k_xd = points
    k_zs = []
    zs_xd = []

    for i in range(len(k_xd)):
        if len(zs_xd) < 5:
            zs_xd.append(k_xd[i])
            continue
        xd_p = k_xd[i]
61 62
        zs_d = max([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'd'])
        zs_g = min([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'g'])
Z
zengbin93 已提交
63 64 65 66 67
        if zs_g <= zs_d:
            zs_xd.append(k_xd[i])
            zs_xd.pop(0)
            continue

Z
zengbin93 已提交
68
        # 定义四个指标,GG=max(gn),G=min(gn),D=max(dn),DD=min(dn),n遍历中枢中所有Zn。
69
        # 定义ZG=min(g1、g2), ZD=max(d1、d2),显然,[ZD,ZG]就是缠中说禅走势中枢的区间
Z
zengbin93 已提交
70
        if xd_p['fx_mark'] == "d" and xd_p['xd'] > zs_g:
71
            zn_points = zs_xd[3:]
Z
zengbin93 已提交
72 73 74 75 76 77 78 79
            # 线段在中枢上方结束,形成三买
            k_zs.append({
                'ZD': zs_d,
                "ZG": zs_g,
                'G': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'GG': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'D': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
                'DD': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
80
                'start_point': zs_xd[1],
Z
zengbin93 已提交
81
                'end_point': zs_xd[-2],
82
                "zn": __get_zn(zn_points),
Z
zengbin93 已提交
83 84 85
                "points": zs_xd,
                "third_buy": xd_p
            })
86
            zs_xd = []
Z
zengbin93 已提交
87
        elif xd_p['fx_mark'] == "g" and xd_p['xd'] < zs_d:
88
            zn_points = zs_xd[3:]
Z
zengbin93 已提交
89 90 91 92 93 94 95 96
            # 线段在中枢下方结束,形成三卖
            k_zs.append({
                'ZD': zs_d,
                "ZG": zs_g,
                'G': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'GG': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'D': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
                'DD': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
97
                'start_point': zs_xd[1],
Z
zengbin93 已提交
98
                'end_point': zs_xd[-2],
Z
zengbin93 已提交
99
                "points": zs_xd,
100
                "zn": __get_zn(zn_points),
Z
zengbin93 已提交
101 102
                "third_sell": xd_p
            })
103
            zs_xd = []
Z
zengbin93 已提交
104 105 106 107
        else:
            zs_xd.append(xd_p)

    if len(zs_xd) >= 5:
108 109
        zs_d = max([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'd'])
        zs_g = min([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'g'])
Z
zengbin93 已提交
110
        if zs_g > zs_d:
111
            zn_points = zs_xd[3:]
Z
zengbin93 已提交
112 113 114 115 116 117 118
            k_zs.append({
                'ZD': zs_d,
                "ZG": zs_g,
                'G': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'GG': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'D': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
                'DD': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
119
                'start_point': zs_xd[1],
Z
zengbin93 已提交
120
                'end_point': None,
121
                "zn": __get_zn(zn_points),
Z
zengbin93 已提交
122 123
                "points": zs_xd,
            })
Z
zengbin93 已提交
124 125
    return k_zs

Z
zengbin93 已提交
126

127
def has_gap(k1, k2, min_gap=0.002):
Z
zengbin93 已提交
128 129
    """判断 k1, k2 之间是否有缺口"""
    assert k2['dt'] > k1['dt']
130 131
    if k1['high'] < k2['low'] * (1-min_gap) \
            or k2['high'] < k1['low'] * (1-min_gap):
Z
zengbin93 已提交
132 133 134 135 136
        return True
    else:
        return False


137 138 139
def make_standard_seq(bi_seq):
    """计算标准特征序列

Z
zengbin93 已提交
140 141
    :param bi_seq: list of dict
        笔标记序列
142
    :return: list of dict
Z
zengbin93 已提交
143
        标准特征序列
144 145 146 147 148 149 150 151
    """
    if bi_seq[0]['fx_mark'] == 'd':
        direction = "up"
    elif bi_seq[0]['fx_mark'] == 'g':
        direction = "down"
    else:
        raise ValueError

152 153 154
    raw_seq = [{"start_dt": bi_seq[i]['dt'], "end_dt": bi_seq[i+1]['dt'],
                'high': max(bi_seq[i]['bi'], bi_seq[i + 1]['bi']),
                'low': min(bi_seq[i]['bi'], bi_seq[i + 1]['bi'])}
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
               for i in range(1, len(bi_seq), 2) if i <= len(bi_seq) - 2]

    seq = []
    for row in raw_seq:
        if not seq:
            seq.append(row)
            continue
        last = seq[-1]
        cur_h, cur_l = row['high'], row['low']
        last_h, last_l = last['high'], last['low']

        # 左包含 or 右包含
        if (cur_h <= last_h and cur_l >= last_l) or (cur_h >= last_h and cur_l <= last_l):
            seq.pop(-1)
            # 有包含关系,按方向分别处理
            if direction == "up":
                last_h = max(last_h, cur_h)
                last_l = max(last_l, cur_l)
            elif direction == "down":
                last_h = min(last_h, cur_h)
                last_l = min(last_l, cur_l)
            else:
                raise ValueError
178
            seq.append({"start_dt": last['start_dt'], "end_dt": row['end_dt'], "high": last_h, "low": last_l})
179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
        else:
            seq.append(row)
    return seq


def is_valid_xd(bi_seq1, bi_seq2, bi_seq3):
    """判断线段标记是否有效(第二个线段标记)

    :param bi_seq1: list of dict
        第一个线段标记到第二个线段标记之间的笔序列
    :param bi_seq2:
        第二个线段标记到第三个线段标记之间的笔序列
    :param bi_seq3:
        第三个线段标记之后的笔序列
    :return:
    """
195
    assert bi_seq2[0]['dt'] == bi_seq1[-1]['dt'] and bi_seq3[0]['dt'] == bi_seq2[-1]['dt']
196 197

    standard_bi_seq1 = make_standard_seq(bi_seq1)
198 199
    if len(standard_bi_seq1) == 0 or len(bi_seq2) < 4:
        return False
200 201 202 203 204 205 206 207 208 209 210 211 212

    # 第一种情况(向下线段)
    if bi_seq2[0]['fx_mark'] == 'd' and bi_seq2[1]['bi'] >= standard_bi_seq1[-1]['low']:
        if bi_seq2[-1]['bi'] < bi_seq2[1]['bi']:
            return False

    # 第一种情况(向上线段)
    if bi_seq2[0]['fx_mark'] == 'g' and bi_seq2[1]['bi'] <= standard_bi_seq1[-1]['high']:
        if bi_seq2[-1]['bi'] > bi_seq2[1]['bi']:
            return False

    # 第二种情况(向下线段)
    if bi_seq2[0]['fx_mark'] == 'd' and bi_seq2[1]['bi'] < standard_bi_seq1[-1]['low']:
213
        bi_seq2.extend(bi_seq3[1:])
214
        standard_bi_seq2 = make_standard_seq(bi_seq2)
215 216 217
        if len(standard_bi_seq2) < 3:
            return False

218
        standard_bi_seq2_g = []
219
        for i in range(1, len(standard_bi_seq2) - 1):
220 221 222
            bi1, bi2, bi3 = standard_bi_seq2[i-1: i+2]
            if bi1['high'] < bi2['high'] > bi3['high']:
                standard_bi_seq2_g.append(bi2)
223 224 225 226 227

                # 如果特征序列顶分型最小值小于底分型,返回 False
                if min([x['low'] for x in standard_bi_seq2[i-1: i+2]]) < bi_seq2[0]['bi']:
                    return False

228 229 230 231 232
        if len(standard_bi_seq2_g) == 0:
            return False

    # 第二种情况(向上线段)
    if bi_seq2[0]['fx_mark'] == 'g' and bi_seq2[1]['bi'] > standard_bi_seq1[-1]['high']:
233
        bi_seq2.extend(bi_seq3[1:])
234
        standard_bi_seq2 = make_standard_seq(bi_seq2)
235 236 237
        if len(standard_bi_seq2) < 3:
            return False

238
        standard_bi_seq2_d = []
239
        for i in range(1, len(standard_bi_seq2) - 1):
240 241 242
            bi1, bi2, bi3 = standard_bi_seq2[i-1: i+2]
            if bi1['low'] > bi2['low'] < bi3['low']:
                standard_bi_seq2_d.append(bi2)
243 244 245 246

                # 如果特征序列的底分型最大值大于顶分型,返回 False
                if max([x['high'] for x in standard_bi_seq2[i-1: i+2]]) > bi_seq2[0]['bi']:
                    return False
247 248 249 250
        if len(standard_bi_seq2_d) == 0:
            return False
    return True

Z
zengbin93 已提交
251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293

def get_potential_xd(bi_points):
    """获取潜在线段标记点

    :param bi_points: list of dict
        笔标记点
    :return: list of dict
        潜在线段标记点
    """
    xd_p = []
    bi_d = [x for x in bi_points if x['fx_mark'] == 'd']
    bi_g = [x for x in bi_points if x['fx_mark'] == 'g']
    for i in range(1, len(bi_d) - 1):
        d1, d2, d3 = bi_d[i - 1: i + 2]
        if d1['bi'] > d2['bi'] < d3['bi']:
            xd_p.append(d2)
    for j in range(1, len(bi_g) - 1):
        g1, g2, g3 = bi_g[j - 1: j + 2]
        if g1['bi'] < g2['bi'] > g3['bi']:
            xd_p.append(g2)

    xd_p = sorted(xd_p, key=lambda x: x['dt'], reverse=False)
    return xd_p


def handle_last_xd(bi_points):
    """处理当下段

    当下段是指当下进行中的无法确认完成的线段,对于操作而言,必须在当下对其进行分析,判断是延续还是转折。

    :param bi_points: list of dict
        最近一个线段标记后面的全部笔标记。在这些笔标记中可能存在 1个、2个或3个需要需要确认的线段标记。
    :return: list of dict
        返回判断结果
    """
    # step 1. 获取潜在分段标记点
    xd_p = get_potential_xd(bi_points)
    if len(xd_p) == 0:
        if bi_points[0]['fx_mark'] != bi_points[-1]['fx_mark']:
            bi_points.pop(-1)



Z
zengbin93 已提交
294
class KlineAnalyze:
Z
zengbin93 已提交
295
    def __init__(self, kline, name="本级别", bi_mode="old", max_raw_len=10000, ma_params=(5, 20, 120), verbose=False):
296 297 298 299
        """

        :param kline: list or pd.DataFrame
        :param name: str
Z
zengbin93 已提交
300 301
        :param bi_mode: str
            new 新笔;old 老笔;默认值为 old
302 303
        :param max_raw_len: int
            原始K线序列的最大长度
Z
zengbin93 已提交
304 305
        :param ma_params: tuple of int
            均线系统参数
306 307
        :param verbose: bool
        """
Z
zengbin93 已提交
308 309
        self.name = name
        self.verbose = verbose
Z
zengbin93 已提交
310
        self.bi_mode = bi_mode
311
        self.max_raw_len = max_raw_len
312
        self.ma_params = ma_params
313 314
        self.kline_raw = []  # 原始K线序列
        self.kline_new = []  # 去除包含关系的K线序列
Z
zengbin93 已提交
315

316 317 318 319
        # 辅助技术指标
        self.ma = []
        self.macd = []

Z
zengbin93 已提交
320 321 322 323 324
        # 分型、笔、线段
        self.fx_list = []
        self.bi_list = []
        self.xd_list = []

Z
zengbin93 已提交
325
        # 根据输入K线初始化
Z
zengbin93 已提交
326 327
        if isinstance(kline, pd.DataFrame):
            columns = kline.columns.to_list()
Z
zengbin93 已提交
328
            self.kline_raw = [{k: v for k, v in zip(columns, row)} for row in kline.values]
Z
zengbin93 已提交
329
        else:
Z
zengbin93 已提交
330
            self.kline_raw = kline
Z
zengbin93 已提交
331

332
        self.kline_raw = self.kline_raw[-self.max_raw_len:]
Z
zengbin93 已提交
333 334 335 336 337
        self.symbol = self.kline_raw[0]['symbol']
        self.start_dt = self.kline_raw[0]['dt']
        self.end_dt = self.kline_raw[-1]['dt']
        self.latest_price = self.kline_raw[-1]['close']

338
        self._update_ta()
Z
zengbin93 已提交
339 340 341 342
        self._update_kline_new()
        self._update_fx_list()
        self._update_bi_list()
        self._update_xd_list()
Z
zengbin93 已提交
343

344 345 346 347 348 349
    def _update_ta(self):
        """更新辅助技术指标"""
        if not self.ma:
            ma_temp = dict()
            close_ = np.array([x["close"] for x in self.kline_raw], dtype=np.double)
            for p in self.ma_params:
Z
zengbin93 已提交
350
                ma_temp['ma%i' % p] = ta.SMA(close_, p)
351 352 353 354 355 356 357 358 359

            for i in range(len(self.kline_raw)):
                ma_ = {'ma%i' % p: ma_temp['ma%i' % p][i] for p in self.ma_params}
                ma_.update({"dt": self.kline_raw[i]['dt']})
                self.ma.append(ma_)
        else:
            ma_ = {'ma%i' % p: sum([x['close'] for x in self.kline_raw[-p:]]) / p
                   for p in self.ma_params}
            ma_.update({"dt": self.kline_raw[-1]['dt']})
360 361 362
            if self.verbose:
                print("ma new: %s" % str(ma_))

363 364 365 366 367
            if self.kline_raw[-2]['dt'] == self.ma[-1]['dt']:
                self.ma.append(ma_)
            else:
                self.ma[-1] = ma_

368
        assert self.ma[-2]['dt'] == self.kline_raw[-2]['dt']
369 370 371 372

        if not self.macd:
            close_ = np.array([x["close"] for x in self.kline_raw], dtype=np.double)
            # m1 is diff; m2 is dea; m3 is macd
Z
zengbin93 已提交
373
            m1, m2, m3 = ta.MACD(close_, fastperiod=12, slowperiod=26, signalperiod=9)
374 375 376 377 378 379 380 381 382 383
            for i in range(len(self.kline_raw)):
                self.macd.append({
                    "dt": self.kline_raw[i]['dt'],
                    "diff": m1[i],
                    "dea": m2[i],
                    "macd": m3[i]
                })
        else:
            close_ = np.array([x["close"] for x in self.kline_raw[-200:]], dtype=np.double)
            # m1 is diff; m2 is dea; m3 is macd
Z
zengbin93 已提交
384
            m1, m2, m3 = ta.MACD(close_, fastperiod=12, slowperiod=26, signalperiod=9)
385
            macd_ = {
386 387 388 389 390
                "dt": self.kline_raw[-1]['dt'],
                "diff": m1[-1],
                "dea": m2[-1],
                "macd": m3[-1]
            }
391 392 393 394
            if self.verbose:
                print("macd new: %s" % str(macd_))

            if self.kline_raw[-2]['dt'] == self.macd[-1]['dt']:
395 396 397 398
                self.macd.append(macd_)
            else:
                self.macd[-1] = macd_

399 400
        assert self.macd[-2]['dt'] == self.kline_raw[-2]['dt']

Z
zengbin93 已提交
401
    def _update_kline_new(self):
Z
zengbin93 已提交
402
        """更新去除包含关系的K线序列"""
Z
zengbin93 已提交
403
        if len(self.kline_new) == 0:
Z
zengbin93 已提交
404
            for x in self.kline_raw[:4]:
Z
zengbin93 已提交
405
                self.kline_new.append(dict(x))
Z
zengbin93 已提交
406 407

        # 新K线只会对最后一个去除包含关系K线的结果产生影响
Z
zengbin93 已提交
408
        self.kline_new = self.kline_new[:-2]
Z
zengbin93 已提交
409 410 411 412 413
        if len(self.kline_new) <= 4:
            right_k = [x for x in self.kline_raw if x['dt'] > self.kline_new[-1]['dt']]
        else:
            right_k = [x for x in self.kline_raw[-100:] if x['dt'] > self.kline_new[-1]['dt']]

Z
zengbin93 已提交
414 415
        if len(right_k) == 0:
            return
Z
zengbin93 已提交
416 417 418 419

        for k in right_k:
            k = dict(k)
            last_kn = self.kline_new[-1]
Z
zengbin93 已提交
420 421 422 423
            if self.kline_new[-1]['high'] > self.kline_new[-2]['high']:
                direction = "up"
            else:
                direction = "down"
Z
zengbin93 已提交
424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448

            # 判断是否存在包含关系
            cur_h, cur_l = k['high'], k['low']
            last_h, last_l = last_kn['high'], last_kn['low']
            if (cur_h <= last_h and cur_l >= last_l) or (cur_h >= last_h and cur_l <= last_l):
                self.kline_new.pop(-1)
                # 有包含关系,按方向分别处理
                if direction == "up":
                    last_h = max(last_h, cur_h)
                    last_l = max(last_l, cur_l)
                elif direction == "down":
                    last_h = min(last_h, cur_h)
                    last_l = min(last_l, cur_l)
                else:
                    raise ValueError

                k.update({"high": last_h, "low": last_l})
                # 保留红绿不变
                if k['open'] >= k['close']:
                    k.update({"open": last_h, "close": last_l})
                else:
                    k.update({"open": last_l, "close": last_h})
            self.kline_new.append(k)

    def _update_fx_list(self):
Z
zengbin93 已提交
449
        """更新分型序列"""
Z
zengbin93 已提交
450 451 452 453 454 455 456
        if len(self.kline_new) < 3:
            return

        self.fx_list = self.fx_list[:-1]
        if len(self.fx_list) == 0:
            kn = self.kline_new
        else:
457
            kn = [x for x in self.kline_new[-100:] if x['dt'] >= self.fx_list[-1]['dt']]
Z
zengbin93 已提交
458 459

        i = 1
460 461
        while i <= len(kn) - 2:
            k1, k2, k3 = kn[i - 1: i + 2]
Z
zengbin93 已提交
462 463 464 465 466 467
            fx_elements = [k1, k2, k3]
            if has_gap(k1, k2):
                fx_elements.pop(0)

            if has_gap(k2, k3):
                fx_elements.pop(-1)
Z
zengbin93 已提交
468 469 470

            if k1['high'] < k2['high'] > k3['high']:
                if self.verbose:
Z
zengbin93 已提交
471
                    print("顶分型:{} - {} - {}".format(k1['dt'], k2['dt'], k3['dt']))
Z
zengbin93 已提交
472 473 474 475
                fx = {
                    "dt": k2['dt'],
                    "fx_mark": "g",
                    "fx": k2['high'],
Z
zengbin93 已提交
476 477
                    "start_dt": k1['dt'],
                    "end_dt": k3['dt'],
Z
zengbin93 已提交
478
                    "fx_high": k2['high'],
Z
zengbin93 已提交
479
                    "fx_low": min([x['low'] for x in fx_elements]),
Z
zengbin93 已提交
480 481 482 483 484
                }
                self.fx_list.append(fx)

            elif k1['low'] > k2['low'] < k3['low']:
                if self.verbose:
Z
zengbin93 已提交
485
                    print("底分型:{} - {} - {}".format(k1['dt'], k2['dt'], k3['dt']))
Z
zengbin93 已提交
486 487 488 489
                fx = {
                    "dt": k2['dt'],
                    "fx_mark": "d",
                    "fx": k2['low'],
Z
zengbin93 已提交
490 491 492
                    "start_dt": k1['dt'],
                    "end_dt": k3['dt'],
                    "fx_high": max([x['high'] for x in fx_elements]),
Z
zengbin93 已提交
493 494 495 496 497 498
                    "fx_low": k2['low'],
                }
                self.fx_list.append(fx)

            else:
                if self.verbose:
Z
zengbin93 已提交
499
                    print("无分型:{} - {} - {}".format(k1['dt'], k2['dt'], k3['dt']))
Z
zengbin93 已提交
500 501 502
            i += 1

    def _update_bi_list(self):
503
        """更新笔序列"""
Z
zengbin93 已提交
504 505 506
        if len(self.fx_list) < 2:
            return

507
        self.bi_list = self.bi_list[:-2]
Z
zengbin93 已提交
508
        if len(self.bi_list) == 0:
Z
zengbin93 已提交
509
            for fx in self.fx_list[:2]:
Z
zengbin93 已提交
510 511 512 513
                bi = dict(fx)
                bi['bi'] = bi.pop('fx')
                self.bi_list.append(bi)

Z
zengbin93 已提交
514 515
        if len(self.bi_list) <= 2:
            right_fx = [x for x in self.fx_list if x['dt'] > self.bi_list[-1]['dt']]
Z
zengbin93 已提交
516 517 518 519 520 521
            if self.bi_mode == "old":
                right_kn = [x for x in self.kline_new if x['dt'] >= self.bi_list[-1]['dt']]
            elif self.bi_mode == 'new':
                right_kn = [x for x in self.kline_raw if x['dt'] >= self.bi_list[-1]['dt']]
            else:
                raise ValueError
Z
zengbin93 已提交
522
        else:
Z
zengbin93 已提交
523 524 525 526 527 528 529
            right_fx = [x for x in self.fx_list[-50:] if x['dt'] > self.bi_list[-1]['dt']]
            if self.bi_mode == "old":
                right_kn = [x for x in self.kline_new[-300:] if x['dt'] >= self.bi_list[-1]['dt']]
            elif self.bi_mode == 'new':
                right_kn = [x for x in self.kline_raw[-300:] if x['dt'] >= self.bi_list[-1]['dt']]
            else:
                raise ValueError
Z
zengbin93 已提交
530 531 532 533 534 535 536 537 538

        for fx in right_fx:
            last_bi = self.bi_list[-1]
            bi = dict(fx)
            bi['bi'] = bi.pop('fx')
            if last_bi['fx_mark'] == fx['fx_mark']:
                if (last_bi['fx_mark'] == 'g' and last_bi['bi'] < bi['bi']) \
                        or (last_bi['fx_mark'] == 'd' and last_bi['bi'] > bi['bi']):
                    if self.verbose:
Z
zengbin93 已提交
539
                        print("笔标记移动:from {} to {}".format(self.bi_list[-1], bi))
Z
zengbin93 已提交
540 541
                    self.bi_list[-1] = bi
            else:
Z
zengbin93 已提交
542 543 544 545 546 547 548 549 550 551 552 553
                kn_inside = [x for x in right_kn if last_bi['end_dt'] < x['dt'] < bi['start_dt']]
                if len(kn_inside) <= 0:
                    continue

                # 确保相邻两个顶底之间不存在包含关系
                if (last_bi['fx_mark'] == 'g' and bi['fx_low'] < last_bi['fx_low']
                    and bi['fx_high'] < last_bi['fx_high']) or \
                        (last_bi['fx_mark'] == 'd' and bi['fx_high'] > last_bi['fx_high']
                         and bi['fx_low'] > last_bi['fx_low']):
                    if self.verbose:
                        print("新增笔标记:{}".format(bi))
                    self.bi_list.append(bi)
Z
zengbin93 已提交
554

555 556 557 558 559
        # if (self.bi_list[-1]['fx_mark'] == 'd' and self.kline_new[-1]['low'] < self.bi_list[-1]['bi']) \
        #         or (self.bi_list[-1]['fx_mark'] == 'g' and self.kline_new[-1]['high'] > self.bi_list[-1]['bi']):
        #     if self.verbose:
        #         print("最后一个笔标记无效,{}".format(self.bi_list[-1]))
        #     self.bi_list.pop(-1)
Z
zengbin93 已提交
560

561 562 563 564
    def _update_xd_list_v1(self):
        """更新线段序列"""
        if len(self.bi_list) < 4:
            return
Z
zengbin93 已提交
565

Z
zengbin93 已提交
566
        self.xd_list = self.xd_list[:-2]
567 568 569 570 571 572 573 574
        if len(self.xd_list) == 0:
            for i in range(3):
                xd = dict(self.bi_list[i])
                xd['xd'] = xd.pop('bi')
                self.xd_list.append(xd)

        if len(self.xd_list) <= 3:
            right_bi = [x for x in self.bi_list if x['dt'] >= self.xd_list[-1]['dt']]
Z
zengbin93 已提交
575
        else:
576
            right_bi = [x for x in self.bi_list[-200:] if x['dt'] >= self.xd_list[-1]['dt']]
Z
zengbin93 已提交
577 578

        xd_p = get_potential_xd(right_bi)
579 580 581 582 583 584 585 586 587 588
        for xp in xd_p:
            xd = dict(xp)
            xd['xd'] = xd.pop('bi')
            last_xd = self.xd_list[-1]
            if last_xd['fx_mark'] == xd['fx_mark']:
                if (last_xd['fx_mark'] == 'd' and last_xd['xd'] > xd['xd']) \
                        or (last_xd['fx_mark'] == 'g' and last_xd['xd'] < xd['xd']):
                    if self.verbose:
                        print("更新线段标记:from {} to {}".format(last_xd, xd))
                    self.xd_list[-1] = xd
Z
zengbin93 已提交
589
            else:
590 591 592
                if (last_xd['fx_mark'] == 'd' and last_xd['xd'] > xd['xd']) \
                        or (last_xd['fx_mark'] == 'g' and last_xd['xd'] < xd['xd']):
                    continue
Z
zengbin93 已提交
593

594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620
                bi_inside = [x for x in right_bi if last_xd['dt'] <= x['dt'] <= xd['dt']]
                if len(bi_inside) < 4:
                    if self.verbose:
                        print("{} - {} 之间笔标记数量少于4,跳过".format(last_xd['dt'], xd['dt']))
                    continue
                else:
                    if len(bi_inside) > 4:
                        if self.verbose:
                            print("新增线段标记(笔标记数量大于4):{}".format(xd))
                        self.xd_list.append(xd)
                    else:
                        bi_r = [x for x in right_bi if x['dt'] >= xd['dt']]
                        assert bi_r[1]['fx_mark'] == bi_inside[-2]['fx_mark']
                        # 第一种情况:没有缺口
                        if (bi_r[1]['fx_mark'] == "g" and bi_r[1]['bi'] > bi_inside[-3]['bi']) \
                                or (bi_r[1]['fx_mark'] == "d" and bi_r[1]['bi'] < bi_inside[-3]['bi']):
                            if self.verbose:
                                print("新增线段标记(第一种情况):{}".format(xd))
                            self.xd_list.append(xd)
                        # 第二种情况:有缺口
                        else:
                            if (bi_r[1]['fx_mark'] == "g" and bi_r[1]['bi'] < bi_inside[-2]['bi']) \
                                    or (bi_r[1]['fx_mark'] == "d" and bi_r[1]['bi'] > bi_inside[-2]['bi']):
                                if self.verbose:
                                    print("新增线段标记(第二种情况):{}".format(xd))
                                self.xd_list.append(xd)

621
    def _xd_after_process(self):
Z
zengbin93 已提交
622
        """线段标记后处理,使用标准特征序列判断线段标记是否成立"""
623
        if not len(self.xd_list) > 4:
Z
zengbin93 已提交
624 625
            return

626 627 628 629 630 631
        keep_xd_index = []
        for i in range(1, len(self.xd_list)-2):
            xd1, xd2, xd3, xd4 = self.xd_list[i-1: i+3]
            bi_seq1 = [x for x in self.bi_list if xd2['dt'] >= x['dt'] >= xd1['dt']]
            bi_seq2 = [x for x in self.bi_list if xd3['dt'] >= x['dt'] >= xd2['dt']]
            bi_seq3 = [x for x in self.bi_list if xd4['dt'] >= x['dt'] >= xd3['dt']]
Z
zengbin93 已提交
632 633 634
            if len(bi_seq1) == 0 or len(bi_seq2) == 0 or len(bi_seq3) == 0:
                continue

635 636 637 638 639 640 641
            if is_valid_xd(bi_seq1, bi_seq2, bi_seq3):
                keep_xd_index.append(i)

        # 处理最后一个确定的线段标记
        bi_seq1 = [x for x in self.bi_list if self.xd_list[-2]['dt'] >= x['dt'] >= self.xd_list[-3]['dt']]
        bi_seq2 = [x for x in self.bi_list if self.xd_list[-1]['dt'] >= x['dt'] >= self.xd_list[-2]['dt']]
        bi_seq3 = [x for x in self.bi_list if x['dt'] >= self.xd_list[-1]['dt']]
Z
zengbin93 已提交
642 643 644
        if not (len(bi_seq1) == 0 or len(bi_seq2) == 0 or len(bi_seq3) == 0):
            if is_valid_xd(bi_seq1, bi_seq2, bi_seq3):
                keep_xd_index.append(len(self.xd_list)-2)
645 646 647 648 649

        new_xd_list = []
        for j in keep_xd_index:
            if not new_xd_list:
                new_xd_list.append(self.xd_list[j])
Z
zengbin93 已提交
650
            else:
651 652 653 654
                if new_xd_list[-1]['fx_mark'] == self.xd_list[j]['fx_mark']:
                    if (new_xd_list[-1]['fx_mark'] == 'd' and new_xd_list[-1]['xd'] > self.xd_list[j]['xd']) \
                            or (new_xd_list[-1]['fx_mark'] == 'g' and new_xd_list[-1]['xd'] < self.xd_list[j]['xd']):
                        new_xd_list[-1] = self.xd_list[j]
Z
zengbin93 已提交
655
                else:
656 657
                    new_xd_list.append(self.xd_list[j])
        self.xd_list = new_xd_list
658 659 660

    def _update_xd_list(self):
        self._update_xd_list_v1()
661
        self._xd_after_process()
662

Z
zengbin93 已提交
663 664 665 666 667 668 669 670 671 672 673
    def update(self, k):
        """更新分析结果

        :param k: dict
            单根K线对象,样例如下
            {'symbol': '000001.SH',
             'dt': Timestamp('2020-07-16 15:00:00'),
             'open': 3356.11,
             'close': 3210.1,
             'high': 3373.53,
             'low': 3209.76,
Z
zengbin93 已提交
674
             'vol': 486366915.0}
Z
zengbin93 已提交
675
        """
Z
zengbin93 已提交
676 677
        if self.verbose:
            print("=" * 100)
Z
zengbin93 已提交
678
            print("输入新K线:{}".format(k))
Z
zengbin93 已提交
679
        if not self.kline_raw or k['open'] != self.kline_raw[-1]['open']:
Z
zengbin93 已提交
680 681 682
            self.kline_raw.append(k)
        else:
            if self.verbose:
Z
zengbin93 已提交
683
                print("输入K线处于未完成状态,更新:replace {} with {}".format(self.kline_raw[-1], k))
Z
zengbin93 已提交
684 685
            self.kline_raw[-1] = k

686
        self._update_ta()
Z
zengbin93 已提交
687 688 689
        self._update_kline_new()
        self._update_fx_list()
        self._update_bi_list()
Z
zengbin93 已提交
690 691
        self._update_xd_list()

Z
zengbin93 已提交
692 693 694
        self.end_dt = self.kline_raw[-1]['dt']
        self.latest_price = self.kline_raw[-1]['close']

695
        # 根据最大原始K线序列长度限制分析结果长度
Z
zengbin93 已提交
696 697
        if len(self.kline_raw) > self.max_raw_len:
            self.kline_raw = self.kline_raw[-self.max_raw_len:]
Z
zengbin93 已提交
698
            self.kline_new = self.kline_new[-self.max_raw_len:]
699 700
            self.ma = self.ma[-self.max_raw_len:]
            self.macd = self.macd[-self.max_raw_len:]
Z
zengbin93 已提交
701 702 703 704 705 706 707 708
            last_dt = self.kline_new[0]['dt']
            self.fx_list = [x for x in self.fx_list if x['dt'] > last_dt]
            self.bi_list = [x for x in self.bi_list if x['dt'] > last_dt]
            self.xd_list = [x for x in self.xd_list if x['dt'] > last_dt]

            # self.fx_list = self.fx_list[-(self.max_raw_len // 2):]
            # self.bi_list = self.bi_list[-(self.max_raw_len // 4):]
            # self.xd_list = self.xd_list[-(self.max_raw_len // 8):]
709

Z
zengbin93 已提交
710 711
        if self.verbose:
            print("更新结束\n\n")
Z
zengbin93 已提交
712

Z
zengbin93 已提交
713
    def to_df(self, ma_params=(5, 20), use_macd=False, max_count=1000, mode="raw"):
Z
zengbin93 已提交
714 715 716 717 718 719
        """整理成 df 输出

        :param ma_params: tuple of int
            均线系统参数
        :param use_macd: bool
        :param max_count: int
Z
zengbin93 已提交
720 721
        :param mode: str
            使用K线类型, raw = 原始K线,new = 去除包含关系的K线
Z
zengbin93 已提交
722 723
        :return: pd.DataFrame
        """
Z
zengbin93 已提交
724 725 726 727 728 729 730
        if mode == "raw":
            bars = self.kline_raw[-max_count:]
        elif mode == "new":
            bars = self.kline_raw[-max_count:]
        else:
            raise ValueError

Z
zengbin93 已提交
731 732 733
        fx_list = {x["dt"]: {"fx_mark": x["fx_mark"], "fx": x['fx']} for x in self.fx_list[-(max_count // 2):]}
        bi_list = {x["dt"]: {"fx_mark": x["fx_mark"], "bi": x['bi']} for x in self.bi_list[-(max_count // 4):]}
        xd_list = {x["dt"]: {"fx_mark": x["fx_mark"], "xd": x['xd']} for x in self.xd_list[-(max_count // 8):]}
Z
zengbin93 已提交
734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751
        results = []
        for k in bars:
            k['fx_mark'], k['fx'], k['bi'], k['xd'] = "o", None, None, None
            fx_ = fx_list.get(k['dt'], None)
            bi_ = bi_list.get(k['dt'], None)
            xd_ = xd_list.get(k['dt'], None)
            if fx_:
                k['fx_mark'] = fx_["fx_mark"]
                k['fx'] = fx_["fx"]

            if bi_:
                k['bi'] = bi_["bi"]

            if xd_:
                k['xd'] = xd_["xd"]

            results.append(k)
        df = pd.DataFrame(results)
Z
zengbin93 已提交
752 753
        for p in ma_params:
            df.loc[:, "ma{}".format(p)] = ta.SMA(df.close.values, p)
Z
zengbin93 已提交
754
        if use_macd:
Z
zengbin93 已提交
755 756 757 758
            diff, dea, macd = ta.MACD(df.close.values)
            df.loc[:, "diff"] = diff
            df.loc[:, "dea"] = diff
            df.loc[:, "macd"] = diff
Z
zengbin93 已提交
759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775
        return df

    def to_image(self, file_image, mav=(5, 20, 120, 250), max_k_count=1000, dpi=50):
        """保存成图片

        :param file_image: str
            图片名称,支持 jpg/png/svg 格式,注意后缀
        :param mav: tuple of int
            均线系统参数
        :param max_k_count: int
            设定最大K线数量,这个值越大,生成的图片越长
        :param dpi: int
            图片分辨率
        :return:
        """
        plot_ka(self, file_image=file_image, mav=mav, max_k_count=max_k_count, dpi=dpi)

776
    def is_bei_chi(self, zs1, zs2, mode="bi", adjust=0.9, last_index: int = None):
777
        """判断 zs1 对 zs2 是否有背驰
Z
zengbin93 已提交
778

779 780 781 782 783 784 785 786 787 788 789 790 791 792 793
        注意:力度的比较,并没有要求两段走势方向一致;但是如果两段走势之间存在包含关系,这样的力度比较是没有意义的。

        :param zs1: dict
            用于比较的走势,通常是最近的走势,示例如下:
            zs1 = {"start_dt": "2020-02-20 11:30:00", "end_dt": "2020-02-20 14:30:00", "direction": "up"}
        :param zs2: dict
            被比较的走势,通常是较前的走势,示例如下:
            zs2 = {"start_dt": "2020-02-21 11:30:00", "end_dt": "2020-02-21 14:30:00", "direction": "down"}
        :param mode: str
            default `bi`, optional value [`xd`, `bi`]
            xd  判断两个线段之间是否存在背驰
            bi  判断两笔之间是否存在背驰
        :param adjust: float
            调整 zs2 的力度,建议设置范围在 0.6 ~ 1.0 之间,默认设置为 0.9;
            其作用是确保 zs1 相比于 zs2 的力度足够小。
794 795
        :param last_index: int
            在比较最后一个走势的时候,可以设置这个参数来提升速度,相当于只对 last_index 后面的K线进行力度比较
796 797 798 799 800 801 802 803
        :return: bool
        """
        assert zs1["start_dt"] > zs2["end_dt"], "zs1 必须是最近的走势,用于比较;zs2 必须是较前的走势,被比较。"
        assert zs1["start_dt"] < zs1["end_dt"], "走势的时间区间定义错误,必须满足 start_dt < end_dt"
        assert zs2["start_dt"] < zs2["end_dt"], "走势的时间区间定义错误,必须满足 start_dt < end_dt"

        min_dt = min(zs1["start_dt"], zs2["start_dt"])
        max_dt = max(zs1["end_dt"], zs2["end_dt"])
804 805 806 807 808
        if last_index:
            macd = self.macd[-last_index:]
        else:
            macd = self.macd
        macd_ = [x for x in macd if x['dt'] >= min_dt]
Z
zengbin93 已提交
809
        macd_ = [x for x in macd_ if max_dt >= x['dt']]
810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842
        k1 = [x for x in macd_ if zs1["end_dt"] >= x['dt'] >= zs1["start_dt"]]
        k2 = [x for x in macd_ if zs2["end_dt"] >= x['dt'] >= zs2["start_dt"]]

        bc = False
        if mode == 'bi':
            macd_sum1 = sum([abs(x['macd']) for x in k1])
            macd_sum2 = sum([abs(x['macd']) for x in k2])
            # print("bi: ", macd_sum1, macd_sum2)
            if macd_sum1 < macd_sum2 * adjust:
                bc = True

        elif mode == 'xd':
            assert zs1['direction'] in ['down', 'up'], "走势的 direction 定义错误,可取值为 up 或 down"
            assert zs2['direction'] in ['down', 'up'], "走势的 direction 定义错误,可取值为 up 或 down"

            if zs1['direction'] == "down":
                macd_sum1 = sum([abs(x['macd']) for x in k1 if x['macd'] < 0])
            else:
                macd_sum1 = sum([abs(x['macd']) for x in k1 if x['macd'] > 0])

            if zs2['direction'] == "down":
                macd_sum2 = sum([abs(x['macd']) for x in k2 if x['macd'] < 0])
            else:
                macd_sum2 = sum([abs(x['macd']) for x in k2 if x['macd'] > 0])

            # print("xd: ", macd_sum1, macd_sum2)
            if macd_sum1 < macd_sum2 * adjust:
                bc = True

        else:
            raise ValueError("mode value error")

        return bc
Z
zengbin93 已提交
843

Z
zengbin93 已提交
844
    def get_sub_section(self, start_dt: datetime, end_dt: datetime, mode="bi", is_last=True):
Z
zengbin93 已提交
845 846 847 848 849 850 851 852
        """获取子区间

        :param start_dt: datetime
            子区间开始时间
        :param end_dt: datetime
            子区间结束时间
        :param mode: str
            需要获取的子区间对象类型,可取值 ['k', 'fx', 'bi', 'xd']
Z
zengbin93 已提交
853 854
        :param is_last: bool
            是否是最近一段子区间
Z
zengbin93 已提交
855 856
        :return: list of dict
        """
Z
zengbin93 已提交
857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880
        if mode == "kn":
            if is_last:
                points = self.kline_new[-200:]
            else:
                points = self.kline_new
        elif mode == "fx":
            if is_last:
                points = self.fx_list[-100:]
            else:
                points = self.fx_list
        elif mode == "bi":
            if is_last:
                points = self.bi_list[-50:]
            else:
                points = self.bi_list
        elif mode == "xd":
            if is_last:
                points = self.xd_list[-30:]
            else:
                points = self.xd_list
        else:
            raise ValueError

        return [x for x in points if end_dt >= x['dt'] >= start_dt]