analyze.py 44.8 KB
Newer Older
Z
zengbin93 已提交
1 2
# coding: utf-8

Z
zengbin93 已提交
3
import warnings
4

Z
zengbin93 已提交
5 6 7 8
try:
    import talib as ta
except ImportError:
    from czsc import ta
9

Z
zengbin93 已提交
10 11 12
    ta_lib_hint = "没有安装 ta-lib !!! 请到 https://www.lfd.uci.edu/~gohlke/pythonlibs/#ta-lib " \
                  "下载对应版本安装,预计分析速度提升2倍"
    warnings.warn(ta_lib_hint)
Z
zengbin93 已提交
13
import pandas as pd
14
import numpy as np
Z
zengbin93 已提交
15
from czsc.utils import ka_to_image
Z
zengbin93 已提交
16

Z
zengbin93 已提交
17 18
def find_zs(points):
    """输入笔或线段标记点,输出中枢识别结果"""
Z
zengbin93 已提交
19
    if len(points) < 5:
Z
zengbin93 已提交
20 21 22
        return []

    # 当输入为笔的标记点时,新增 xd 值
23
    for j, x in enumerate(points):
Z
zengbin93 已提交
24
        if x.get("bi", 0):
25
            points[j]['xd'] = x["bi"]
Z
zengbin93 已提交
26

27 28 29 30 31 32 33 34 35 36
    def __get_zn(zn_points_):
        """把与中枢方向一致的次级别走势类型称为Z走势段,按中枢中的时间顺序,
        分别记为Zn等,而相应的高、低点分别记为gn、dn"""
        if len(zn_points_) % 2 != 0:
            zn_points_ = zn_points_[:-1]

        if zn_points_[0]['fx_mark'] == "d":
            z_direction = "up"
        else:
            z_direction = "down"
37 38 39 40 41 42 43 44 45 46 47 48

        zn = []
        for i in range(0, len(zn_points_), 2):
            zn_ = {
                "start_dt": zn_points_[i]['dt'],
                "end_dt": zn_points_[i + 1]['dt'],
                "high": max(zn_points_[i]['xd'], zn_points_[i + 1]['xd']),
                "low": min(zn_points_[i]['xd'], zn_points_[i + 1]['xd']),
                "direction": z_direction
            }
            zn_['mid'] = zn_['low'] + (zn_['high'] - zn_['low']) / 2
            zn.append(zn_)
49 50
        return zn

Z
zengbin93 已提交
51 52 53 54 55 56 57 58 59
    k_xd = points
    k_zs = []
    zs_xd = []

    for i in range(len(k_xd)):
        if len(zs_xd) < 5:
            zs_xd.append(k_xd[i])
            continue
        xd_p = k_xd[i]
60 61
        zs_d = max([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'd'])
        zs_g = min([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'g'])
Z
zengbin93 已提交
62 63 64 65 66
        if zs_g <= zs_d:
            zs_xd.append(k_xd[i])
            zs_xd.pop(0)
            continue

Z
zengbin93 已提交
67
        # 定义四个指标,GG=max(gn),G=min(gn),D=max(dn),DD=min(dn),n遍历中枢中所有Zn。
68
        # 定义ZG=min(g1、g2), ZD=max(d1、d2),显然,[ZD,ZG]就是缠中说禅走势中枢的区间
Z
zengbin93 已提交
69
        if xd_p['fx_mark'] == "d" and xd_p['xd'] > zs_g:
70
            zn_points = zs_xd[3:]
Z
zengbin93 已提交
71 72 73 74 75 76 77 78
            # 线段在中枢上方结束,形成三买
            k_zs.append({
                'ZD': zs_d,
                "ZG": zs_g,
                'G': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'GG': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'D': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
                'DD': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
79
                'start_point': zs_xd[1],
Z
zengbin93 已提交
80
                'end_point': zs_xd[-2],
81
                "zn": __get_zn(zn_points),
Z
zengbin93 已提交
82 83 84
                "points": zs_xd,
                "third_buy": xd_p
            })
85
            zs_xd = []
Z
zengbin93 已提交
86
        elif xd_p['fx_mark'] == "g" and xd_p['xd'] < zs_d:
87
            zn_points = zs_xd[3:]
Z
zengbin93 已提交
88 89 90 91 92 93 94 95
            # 线段在中枢下方结束,形成三卖
            k_zs.append({
                'ZD': zs_d,
                "ZG": zs_g,
                'G': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'GG': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'D': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
                'DD': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
96
                'start_point': zs_xd[1],
Z
zengbin93 已提交
97
                'end_point': zs_xd[-2],
Z
zengbin93 已提交
98
                "points": zs_xd,
99
                "zn": __get_zn(zn_points),
Z
zengbin93 已提交
100 101
                "third_sell": xd_p
            })
102
            zs_xd = []
Z
zengbin93 已提交
103 104 105 106
        else:
            zs_xd.append(xd_p)

    if len(zs_xd) >= 5:
107 108
        zs_d = max([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'd'])
        zs_g = min([x['xd'] for x in zs_xd[:4] if x['fx_mark'] == 'g'])
Z
zengbin93 已提交
109
        if zs_g > zs_d:
110
            zn_points = zs_xd[3:]
Z
zengbin93 已提交
111 112 113 114 115 116 117
            k_zs.append({
                'ZD': zs_d,
                "ZG": zs_g,
                'G': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'GG': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'g']),
                'D': max([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
                'DD': min([x['xd'] for x in zs_xd if x['fx_mark'] == 'd']),
118
                'start_point': zs_xd[1],
Z
zengbin93 已提交
119
                'end_point': None,
120
                "zn": __get_zn(zn_points),
Z
zengbin93 已提交
121 122
                "points": zs_xd,
            })
Z
zengbin93 已提交
123 124
    return k_zs

Z
zengbin93 已提交
125

126
def has_gap(k1, k2, min_gap=0.002):
Z
zengbin93 已提交
127
    """判断 k1, k2 之间是否有缺口"""
Z
zengbin93 已提交
128
    assert k2['dt'] > k1['dt']
129 130
    if k1['high'] < k2['low'] * (1-min_gap) \
            or k2['high'] < k1['low'] * (1-min_gap):
Z
zengbin93 已提交
131 132 133 134 135
        return True
    else:
        return False


136 137 138
def make_standard_seq(bi_seq):
    """计算标准特征序列

Z
zengbin93 已提交
139 140
    :param bi_seq: list of dict
        笔标记序列
141
    :return: list of dict
Z
zengbin93 已提交
142
        标准特征序列
143 144 145 146 147 148 149 150
    """
    if bi_seq[0]['fx_mark'] == 'd':
        direction = "up"
    elif bi_seq[0]['fx_mark'] == 'g':
        direction = "down"
    else:
        raise ValueError

151 152 153
    raw_seq = [{"start_dt": bi_seq[i]['dt'], "end_dt": bi_seq[i+1]['dt'],
                'high': max(bi_seq[i]['bi'], bi_seq[i + 1]['bi']),
                'low': min(bi_seq[i]['bi'], bi_seq[i + 1]['bi'])}
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
               for i in range(1, len(bi_seq), 2) if i <= len(bi_seq) - 2]

    seq = []
    for row in raw_seq:
        if not seq:
            seq.append(row)
            continue
        last = seq[-1]
        cur_h, cur_l = row['high'], row['low']
        last_h, last_l = last['high'], last['low']

        # 左包含 or 右包含
        if (cur_h <= last_h and cur_l >= last_l) or (cur_h >= last_h and cur_l <= last_l):
            seq.pop(-1)
            # 有包含关系,按方向分别处理
            if direction == "up":
                last_h = max(last_h, cur_h)
                last_l = max(last_l, cur_l)
            elif direction == "down":
                last_h = min(last_h, cur_h)
                last_l = min(last_l, cur_l)
            else:
                raise ValueError
177
            seq.append({"start_dt": last['start_dt'], "end_dt": row['end_dt'], "high": last_h, "low": last_l})
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
        else:
            seq.append(row)
    return seq


def is_valid_xd(bi_seq1, bi_seq2, bi_seq3):
    """判断线段标记是否有效(第二个线段标记)

    :param bi_seq1: list of dict
        第一个线段标记到第二个线段标记之间的笔序列
    :param bi_seq2:
        第二个线段标记到第三个线段标记之间的笔序列
    :param bi_seq3:
        第三个线段标记之后的笔序列
    :return:
    """
194
    assert bi_seq2[0]['dt'] == bi_seq1[-1]['dt'] and bi_seq3[0]['dt'] == bi_seq2[-1]['dt']
195 196

    standard_bi_seq1 = make_standard_seq(bi_seq1)
197 198
    if len(standard_bi_seq1) == 0 or len(bi_seq2) < 4:
        return False
199 200

    # 第一种情况(向下线段)
Z
zengbin93 已提交
201 202
    # if bi_seq2[0]['fx_mark'] == 'd' and bi_seq2[1]['bi'] >= standard_bi_seq1[-1]['low']:
    if bi_seq2[0]['fx_mark'] == 'd' and bi_seq2[1]['bi'] >= min([x['low'] for x in standard_bi_seq1]):
203 204 205 206
        if bi_seq2[-1]['bi'] < bi_seq2[1]['bi']:
            return False

    # 第一种情况(向上线段)
Z
zengbin93 已提交
207 208
    # if bi_seq2[0]['fx_mark'] == 'g' and bi_seq2[1]['bi'] <= standard_bi_seq1[-1]['high']:
    if bi_seq2[0]['fx_mark'] == 'g' and bi_seq2[1]['bi'] <= max([x['high'] for x in standard_bi_seq1]):
209 210 211 212
        if bi_seq2[-1]['bi'] > bi_seq2[1]['bi']:
            return False

    # 第二种情况(向下线段)
Z
zengbin93 已提交
213 214
    # if bi_seq2[0]['fx_mark'] == 'd' and bi_seq2[1]['bi'] < standard_bi_seq1[-1]['low']:
    if bi_seq2[0]['fx_mark'] == 'd' and bi_seq2[1]['bi'] < min([x['low'] for x in standard_bi_seq1]):
215
        bi_seq2.extend(bi_seq3[1:])
216
        standard_bi_seq2 = make_standard_seq(bi_seq2)
217 218 219
        if len(standard_bi_seq2) < 3:
            return False

220
        standard_bi_seq2_g = []
221
        for i in range(1, len(standard_bi_seq2) - 1):
222 223 224
            bi1, bi2, bi3 = standard_bi_seq2[i-1: i+2]
            if bi1['high'] < bi2['high'] > bi3['high']:
                standard_bi_seq2_g.append(bi2)
225

Z
zengbin93 已提交
226 227
                # 特征序列顶分型完全在底分型区间,返回 False
                if min(bi1['low'], bi2['low'], bi3['low']) < bi_seq2[0]['bi']:
228 229
                    return False

230 231 232 233
        if len(standard_bi_seq2_g) == 0:
            return False

    # 第二种情况(向上线段)
Z
zengbin93 已提交
234 235
    # if bi_seq2[0]['fx_mark'] == 'g' and bi_seq2[1]['bi'] > standard_bi_seq1[-1]['high']:
    if bi_seq2[0]['fx_mark'] == 'g' and bi_seq2[1]['bi'] > max([x['high'] for x in standard_bi_seq1]):
236
        bi_seq2.extend(bi_seq3[1:])
237
        standard_bi_seq2 = make_standard_seq(bi_seq2)
238 239 240
        if len(standard_bi_seq2) < 3:
            return False

241
        standard_bi_seq2_d = []
242
        for i in range(1, len(standard_bi_seq2) - 1):
243 244 245
            bi1, bi2, bi3 = standard_bi_seq2[i-1: i+2]
            if bi1['low'] > bi2['low'] < bi3['low']:
                standard_bi_seq2_d.append(bi2)
246

Z
zengbin93 已提交
247 248
                # 特征序列的底分型在顶分型区间,返回 False
                if max(bi1['high'], bi2['high'], bi3['high']) > bi_seq2[0]['bi']:
249
                    return False
Z
zengbin93 已提交
250

251 252 253 254
        if len(standard_bi_seq2_d) == 0:
            return False
    return True

Z
zengbin93 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279

def get_potential_xd(bi_points):
    """获取潜在线段标记点

    :param bi_points: list of dict
        笔标记点
    :return: list of dict
        潜在线段标记点
    """
    xd_p = []
    bi_d = [x for x in bi_points if x['fx_mark'] == 'd']
    bi_g = [x for x in bi_points if x['fx_mark'] == 'g']
    for i in range(1, len(bi_d) - 1):
        d1, d2, d3 = bi_d[i - 1: i + 2]
        if d1['bi'] > d2['bi'] < d3['bi']:
            xd_p.append(d2)
    for j in range(1, len(bi_g) - 1):
        g1, g2, g3 = bi_g[j - 1: j + 2]
        if g1['bi'] < g2['bi'] > g3['bi']:
            xd_p.append(g2)

    xd_p = sorted(xd_p, key=lambda x: x['dt'], reverse=False)
    return xd_p


Z
zengbin93 已提交
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437
def check_jing(fd1, fd2, fd3, fd4, fd5):
    """检查最近5个分段走势是否构成井

    井的定义:
        12345,五段,是构造井的基本形态,形成井的位置肯定是5,而5出井的
        前提条件是对于向上5至少比3和1其中之一高,向下反过来; 并且,234
        构成一个中枢。

        井只有两类,大井和小井(以向上为例):
        大井对应的形式是:12345向上,5最高3次之1最低,力度上1大于3,3大于5;
        小井对应的形式是:
            1:12345向上,3最高5次之1最低,力度上5的力度比1小,注意这时候
               不需要再考虑5和3的关系了,因为5比3低,所以不需要考虑力度。
            2:12345向上,5最高3次之1最低,力度上1大于5,5大于3。

        小井的构造,关键是满足5一定至少大于1、3中的一个。
        注意有一种情况不归为井:就是12345向上,1的力度最小,5的力度次之,3的力度最大此类不算井,
        因为345后面必然还有走势在67的时候才能再判断,个中道理各位好好体会。


    fd 为 dict 对象,表示一段走势,可以是笔、线段,样例如下:

    fd = {
        "start_dt": "",
        "end_dt": "",
        "power": 0,         # 力度
        "direction": "up",
        "high": 0,
        "low": 0,
        "mode": "bi"
    }

    假定最近一段走势为第N段;则 fd1 为第N-4段走势, fd2为第N-3段走势,
    fd3为第N-2段走势, fd4为第N-1段走势, fd5为第N段走势

    """
    assert fd1['direction'] == fd3['direction'] == fd5['direction']
    assert fd2['direction'] == fd4['direction']
    direction = fd1['direction']

    zs_g = min(fd2['high'], fd3['high'], fd4['high'])
    zs_d = max(fd2['low'], fd3['low'], fd4['low'])

    jing = {"jing": "没有出井", "notes": ""}

    # 1的力度最小,5的力度次之,3的力度最大,此类不算井
    if fd1['power'] < fd5['power'] < fd3['power']:
        jing['notes'] = "1的力度最小,5的力度次之,3的力度最大,此类不算井"
        return jing

    if zs_d < zs_g:     # 234有中枢的情况
        if direction == 'up' and fd5["high"] > min(fd3['high'], fd1['high']):

            # 大井对应的形式是:12345向上,5最高3次之1最低,力度上1大于3,3大于5
            if fd5["high"] > fd3['high'] > fd1['high'] and fd5['power'] < fd3['power'] < fd1['power']:
                jing = {"jing": "向上大井", "notes": "12345向上,5最高3次之1最低,力度上1大于3,3大于5"}

            # 第一种小井:12345向上,3最高5次之1最低,力度上5的力度比1小
            if fd1['high'] < fd5['high'] < fd3['high'] and fd5['power'] < fd1['power']:
                jing = {"jing": "向上小井", "notes": "12345向上,3最高5次之1最低,力度上5的力度比1小"}

            # 第二种小井:12345向上,5最高3次之1最低,力度上1大于5,5大于3
            if fd5["high"] > fd3['high'] > fd1['high'] and fd1['power'] > fd5['power'] > fd3['power']:
                jing = {"jing": "向上小井", "notes": "12345向上,5最高3次之1最低,力度上1大于5,5大于3"}

        if direction == 'down' and fd5["low"] < max(fd3['low'], fd1['low']):

            # 大井对应的形式是:12345向下,5最低3次之1最高,力度上1大于3,3大于5
            if fd5['low'] < fd3['low'] < fd1['low'] and fd5['power'] < fd3['power'] < fd1['power']:
                jing = {"jing": "向下大井", "notes": "12345向下,5最低3次之1最高,力度上1大于3,3大于5"}

            # 第一种小井:12345向下,3最低5次之1最高,力度上5的力度比1小
            if fd1["low"] > fd5['low'] > fd3['low'] and fd5['power'] < fd1['power']:
                jing = {"jing": "向下小井", "notes": "12345向下,3最低5次之1最高,力度上5的力度比1小"}

            # 第二种小井:12345向下,5最低3次之1最高,力度上1大于5,5大于3
            if fd5['low'] < fd3['low'] < fd1['low'] and fd1['power'] > fd5['power'] > fd3['power']:
                jing = {"jing": "向下小井", "notes": "12345向下,5最低3次之1最高,力度上1大于5,5大于3"}
    else:
        # 第三种小井:12345类趋势,力度依次降低,可以看成小井
        if fd1['power'] > fd3['power'] > fd5['power']:
            if direction == 'up' and fd5["high"] > fd3['high'] > fd1['high']:
                jing = {"jing": "向上小井", "notes": "12345类上涨趋势,力度依次降低"}

            if direction == 'down' and fd5["low"] < fd3['low'] < fd1['low']:
                jing = {"jing": "向下小井", "notes": "12345类下跌趋势,力度依次降低"}

    return jing


def check_bei_chi(fd1, fd2, fd3, fd4, fd5):
    """检查最近5个分段走势是否有背驰

    fd 为 dict 对象,表示一段走势,可以是笔、线段,样例如下:

    fd = {
        "start_dt": "",
        "end_dt": "",
        "power": 0,         # 力度
        "direction": "up",
        "high": 0,
        "low": 0,
        "mode": "bi"
    }

    """
    assert fd1['direction'] == fd3['direction'] == fd5['direction']
    assert fd2['direction'] == fd4['direction']
    direction = fd1['direction']

    zs_g = min(fd2['high'], fd3['high'], fd4['high'])
    zs_d = max(fd2['low'], fd3['low'], fd4['low'])

    bc = {"bc": "没有背驰", "notes": ""}
    if max(fd5['power'], fd3['power'], fd1['power']) == fd5['power']:
        bc = {"bc": "没有背驰", "notes": "5的力度最大,没有背驰"}
        return bc

    if zs_d < zs_g:
        if fd5['power'] < fd1['power']:
            if direction == 'up' and fd5["high"] > min(fd3['high'], fd1['high']):
                bc = {"bc": "向上趋势背驰", "notes": "12345向上,234构成中枢,5最高,力度上1大于5"}

            if direction == 'down' and fd5["low"] < max(fd3['low'], fd1['low']):
                bc = {"bc": "向下趋势背驰", "notes": "12345向下,234构成中枢,5最低,力度上1大于5"}
    else:
        if fd5['power'] < fd3['power']:
            if direction == 'up' and fd5["high"] > fd3['high']:
                bc = {"bc": "向上盘整背驰", "notes": "12345向上,234不构成中枢,5最高,力度上1大于5"}

            if direction == 'down' and fd5["low"] < fd3['low']:
                bc = {"bc": "向下盘整背驰", "notes": "12345向下,234不构成中枢,5最低,力度上1大于5"}

    return bc


def check_third_bs(fd1, fd2, fd3, fd4, fd5):
    """输入5段走势,判断是否存在第三类买卖点"""
    zs_d = max(fd1['low'], fd2['low'], fd3['low'])
    zs_g = min(fd1['high'], fd2['high'], fd3['high'])

    third_bs = {"third_bs": "没有第三类买卖点", "notes": ""}

    if max(fd1['power'], fd2['power'], fd3['power'], fd4['power'], fd5['power']) != fd4['power']:
        third_bs = {"third_bs": "没有第三类买卖点", "notes": "第四段不是力度最大的段"}
        return third_bs

    if zs_g < zs_d:
        third_bs = {"third_bs": "没有第三类买卖点", "notes": "前三段不构成中枢,无第三类买卖点"}
    else:
        if fd4['low'] < zs_d and fd5['high'] < zs_d:
            third_bs = {"third_bs": "三卖", "notes": "前三段构成中枢,第四段向下离开,第五段不回中枢"}

        if fd4['high'] > zs_g and fd5['low'] > zs_g:
            third_bs = {"third_bs": "三买", "notes": "前三段构成中枢,第四段向上离开,第五段不回中枢"}
    return third_bs


Z
zengbin93 已提交
438
class KlineAnalyze:
Z
zengbin93 已提交
439
    def __init__(self, kline, name="本级别", bi_mode="new", max_xd_len=20, ma_params=(5, 34, 120), verbose=False):
440 441 442 443
        """

        :param kline: list or pd.DataFrame
        :param name: str
Z
zengbin93 已提交
444
        :param bi_mode: str
Z
zengbin93 已提交
445
            new 新笔;old 老笔;默认值为 new
Z
zengbin93 已提交
446 447
        :param max_xd_len: int
            线段标记序列的最大长度
Z
zengbin93 已提交
448 449
        :param ma_params: tuple of int
            均线系统参数
450 451
        :param verbose: bool
        """
Z
zengbin93 已提交
452 453
        self.name = name
        self.verbose = verbose
Z
zengbin93 已提交
454
        self.bi_mode = bi_mode
Z
zengbin93 已提交
455
        self.max_xd_len = max_xd_len
456
        self.ma_params = ma_params
457 458
        self.kline_raw = []  # 原始K线序列
        self.kline_new = []  # 去除包含关系的K线序列
Z
zengbin93 已提交
459

460 461 462 463
        # 辅助技术指标
        self.ma = []
        self.macd = []

Z
zengbin93 已提交
464 465 466 467 468
        # 分型、笔、线段
        self.fx_list = []
        self.bi_list = []
        self.xd_list = []

Z
zengbin93 已提交
469
        # 根据输入K线初始化
Z
zengbin93 已提交
470 471
        if isinstance(kline, pd.DataFrame):
            columns = kline.columns.to_list()
Z
zengbin93 已提交
472
            self.kline_raw = [{k: v for k, v in zip(columns, row)} for row in kline.values]
Z
zengbin93 已提交
473
        else:
Z
zengbin93 已提交
474
            self.kline_raw = kline
Z
zengbin93 已提交
475 476 477 478 479 480

        self.symbol = self.kline_raw[0]['symbol']
        self.start_dt = self.kline_raw[0]['dt']
        self.end_dt = self.kline_raw[-1]['dt']
        self.latest_price = self.kline_raw[-1]['close']

481
        self._update_ta()
Z
zengbin93 已提交
482 483 484 485
        self._update_kline_new()
        self._update_fx_list()
        self._update_bi_list()
        self._update_xd_list()
Z
zengbin93 已提交
486

487 488 489 490 491 492
    def _update_ta(self):
        """更新辅助技术指标"""
        if not self.ma:
            ma_temp = dict()
            close_ = np.array([x["close"] for x in self.kline_raw], dtype=np.double)
            for p in self.ma_params:
Z
zengbin93 已提交
493
                ma_temp['ma%i' % p] = ta.SMA(close_, p)
494 495 496 497 498 499 500 501 502

            for i in range(len(self.kline_raw)):
                ma_ = {'ma%i' % p: ma_temp['ma%i' % p][i] for p in self.ma_params}
                ma_.update({"dt": self.kline_raw[i]['dt']})
                self.ma.append(ma_)
        else:
            ma_ = {'ma%i' % p: sum([x['close'] for x in self.kline_raw[-p:]]) / p
                   for p in self.ma_params}
            ma_.update({"dt": self.kline_raw[-1]['dt']})
503 504 505
            if self.verbose:
                print("ma new: %s" % str(ma_))

506 507 508 509 510
            if self.kline_raw[-2]['dt'] == self.ma[-1]['dt']:
                self.ma.append(ma_)
            else:
                self.ma[-1] = ma_

511
        assert self.ma[-2]['dt'] == self.kline_raw[-2]['dt']
512 513 514 515

        if not self.macd:
            close_ = np.array([x["close"] for x in self.kline_raw], dtype=np.double)
            # m1 is diff; m2 is dea; m3 is macd
Z
zengbin93 已提交
516
            m1, m2, m3 = ta.MACD(close_, fastperiod=12, slowperiod=26, signalperiod=9)
517 518 519 520 521 522 523 524 525 526
            for i in range(len(self.kline_raw)):
                self.macd.append({
                    "dt": self.kline_raw[i]['dt'],
                    "diff": m1[i],
                    "dea": m2[i],
                    "macd": m3[i]
                })
        else:
            close_ = np.array([x["close"] for x in self.kline_raw[-200:]], dtype=np.double)
            # m1 is diff; m2 is dea; m3 is macd
Z
zengbin93 已提交
527
            m1, m2, m3 = ta.MACD(close_, fastperiod=12, slowperiod=26, signalperiod=9)
528
            macd_ = {
529 530 531 532 533
                "dt": self.kline_raw[-1]['dt'],
                "diff": m1[-1],
                "dea": m2[-1],
                "macd": m3[-1]
            }
534 535 536 537
            if self.verbose:
                print("macd new: %s" % str(macd_))

            if self.kline_raw[-2]['dt'] == self.macd[-1]['dt']:
538 539 540 541
                self.macd.append(macd_)
            else:
                self.macd[-1] = macd_

542 543
        assert self.macd[-2]['dt'] == self.kline_raw[-2]['dt']

Z
zengbin93 已提交
544
    def _update_kline_new(self):
Z
zengbin93 已提交
545
        """更新去除包含关系的K线序列"""
Z
zengbin93 已提交
546
        if len(self.kline_new) < 4:
Z
zengbin93 已提交
547
            for x in self.kline_raw[:4]:
Z
zengbin93 已提交
548
                self.kline_new.append(dict(x))
Z
zengbin93 已提交
549 550

        # 新K线只会对最后一个去除包含关系K线的结果产生影响
Z
zengbin93 已提交
551
        self.kline_new = self.kline_new[:-2]
Z
zengbin93 已提交
552 553 554 555 556
        if len(self.kline_new) <= 4:
            right_k = [x for x in self.kline_raw if x['dt'] > self.kline_new[-1]['dt']]
        else:
            right_k = [x for x in self.kline_raw[-100:] if x['dt'] > self.kline_new[-1]['dt']]

Z
zengbin93 已提交
557 558
        if len(right_k) == 0:
            return
Z
zengbin93 已提交
559 560 561 562

        for k in right_k:
            k = dict(k)
            last_kn = self.kline_new[-1]
Z
zengbin93 已提交
563 564 565 566
            if self.kline_new[-1]['high'] > self.kline_new[-2]['high']:
                direction = "up"
            else:
                direction = "down"
Z
zengbin93 已提交
567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591

            # 判断是否存在包含关系
            cur_h, cur_l = k['high'], k['low']
            last_h, last_l = last_kn['high'], last_kn['low']
            if (cur_h <= last_h and cur_l >= last_l) or (cur_h >= last_h and cur_l <= last_l):
                self.kline_new.pop(-1)
                # 有包含关系,按方向分别处理
                if direction == "up":
                    last_h = max(last_h, cur_h)
                    last_l = max(last_l, cur_l)
                elif direction == "down":
                    last_h = min(last_h, cur_h)
                    last_l = min(last_l, cur_l)
                else:
                    raise ValueError

                k.update({"high": last_h, "low": last_l})
                # 保留红绿不变
                if k['open'] >= k['close']:
                    k.update({"open": last_h, "close": last_l})
                else:
                    k.update({"open": last_l, "close": last_h})
            self.kline_new.append(k)

    def _update_fx_list(self):
Z
zengbin93 已提交
592
        """更新分型序列"""
Z
zengbin93 已提交
593 594 595 596 597 598 599
        if len(self.kline_new) < 3:
            return

        self.fx_list = self.fx_list[:-1]
        if len(self.fx_list) == 0:
            kn = self.kline_new
        else:
600
            kn = [x for x in self.kline_new[-100:] if x['dt'] >= self.fx_list[-1]['dt']]
Z
zengbin93 已提交
601 602

        i = 1
603 604
        while i <= len(kn) - 2:
            k1, k2, k3 = kn[i - 1: i + 2]
Z
zengbin93 已提交
605 606 607 608 609 610
            fx_elements = [k1, k2, k3]
            if has_gap(k1, k2):
                fx_elements.pop(0)

            if has_gap(k2, k3):
                fx_elements.pop(-1)
Z
zengbin93 已提交
611 612 613

            if k1['high'] < k2['high'] > k3['high']:
                if self.verbose:
Z
zengbin93 已提交
614
                    print("顶分型:{} - {} - {}".format(k1['dt'], k2['dt'], k3['dt']))
Z
zengbin93 已提交
615 616 617 618
                fx = {
                    "dt": k2['dt'],
                    "fx_mark": "g",
                    "fx": k2['high'],
Z
zengbin93 已提交
619 620
                    "start_dt": k1['dt'],
                    "end_dt": k3['dt'],
Z
zengbin93 已提交
621
                    "fx_high": k2['high'],
Z
zengbin93 已提交
622
                    "fx_low": min([x['low'] for x in fx_elements]),
Z
zengbin93 已提交
623 624 625 626 627
                }
                self.fx_list.append(fx)

            elif k1['low'] > k2['low'] < k3['low']:
                if self.verbose:
Z
zengbin93 已提交
628
                    print("底分型:{} - {} - {}".format(k1['dt'], k2['dt'], k3['dt']))
Z
zengbin93 已提交
629 630 631 632
                fx = {
                    "dt": k2['dt'],
                    "fx_mark": "d",
                    "fx": k2['low'],
Z
zengbin93 已提交
633 634 635
                    "start_dt": k1['dt'],
                    "end_dt": k3['dt'],
                    "fx_high": max([x['high'] for x in fx_elements]),
Z
zengbin93 已提交
636 637 638 639 640 641
                    "fx_low": k2['low'],
                }
                self.fx_list.append(fx)

            else:
                if self.verbose:
Z
zengbin93 已提交
642
                    print("无分型:{} - {} - {}".format(k1['dt'], k2['dt'], k3['dt']))
Z
zengbin93 已提交
643 644 645
            i += 1

    def _update_bi_list(self):
646
        """更新笔序列"""
Z
zengbin93 已提交
647 648 649
        if len(self.fx_list) < 2:
            return

650
        self.bi_list = self.bi_list[:-2]
Z
zengbin93 已提交
651
        if len(self.bi_list) == 0:
Z
zengbin93 已提交
652
            for fx in self.fx_list[:2]:
Z
zengbin93 已提交
653 654 655 656
                bi = dict(fx)
                bi['bi'] = bi.pop('fx')
                self.bi_list.append(bi)

Z
zengbin93 已提交
657 658
        if len(self.bi_list) <= 2:
            right_fx = [x for x in self.fx_list if x['dt'] > self.bi_list[-1]['dt']]
Z
zengbin93 已提交
659 660 661 662 663 664
            if self.bi_mode == "old":
                right_kn = [x for x in self.kline_new if x['dt'] >= self.bi_list[-1]['dt']]
            elif self.bi_mode == 'new':
                right_kn = [x for x in self.kline_raw if x['dt'] >= self.bi_list[-1]['dt']]
            else:
                raise ValueError
Z
zengbin93 已提交
665
        else:
Z
zengbin93 已提交
666 667 668 669 670 671 672
            right_fx = [x for x in self.fx_list[-50:] if x['dt'] > self.bi_list[-1]['dt']]
            if self.bi_mode == "old":
                right_kn = [x for x in self.kline_new[-300:] if x['dt'] >= self.bi_list[-1]['dt']]
            elif self.bi_mode == 'new':
                right_kn = [x for x in self.kline_raw[-300:] if x['dt'] >= self.bi_list[-1]['dt']]
            else:
                raise ValueError
Z
zengbin93 已提交
673 674 675 676 677 678 679 680 681

        for fx in right_fx:
            last_bi = self.bi_list[-1]
            bi = dict(fx)
            bi['bi'] = bi.pop('fx')
            if last_bi['fx_mark'] == fx['fx_mark']:
                if (last_bi['fx_mark'] == 'g' and last_bi['bi'] < bi['bi']) \
                        or (last_bi['fx_mark'] == 'd' and last_bi['bi'] > bi['bi']):
                    if self.verbose:
Z
zengbin93 已提交
682
                        print("笔标记移动:from {} to {}".format(self.bi_list[-1], bi))
Z
zengbin93 已提交
683 684
                    self.bi_list[-1] = bi
            else:
Z
zengbin93 已提交
685 686 687 688 689 690 691 692 693 694 695 696
                kn_inside = [x for x in right_kn if last_bi['end_dt'] < x['dt'] < bi['start_dt']]
                if len(kn_inside) <= 0:
                    continue

                # 确保相邻两个顶底之间不存在包含关系
                if (last_bi['fx_mark'] == 'g' and bi['fx_low'] < last_bi['fx_low']
                    and bi['fx_high'] < last_bi['fx_high']) or \
                        (last_bi['fx_mark'] == 'd' and bi['fx_high'] > last_bi['fx_high']
                         and bi['fx_low'] > last_bi['fx_low']):
                    if self.verbose:
                        print("新增笔标记:{}".format(bi))
                    self.bi_list.append(bi)
Z
zengbin93 已提交
697

Z
zengbin93 已提交
698 699 700 701 702
        if (self.bi_list[-1]['fx_mark'] == 'd' and self.kline_new[-1]['low'] < self.bi_list[-1]['bi']) \
                or (self.bi_list[-1]['fx_mark'] == 'g' and self.kline_new[-1]['high'] > self.bi_list[-1]['bi']):
            if self.verbose:
                print("最后一个笔标记无效,{}".format(self.bi_list[-1]))
            self.bi_list.pop(-1)
Z
zengbin93 已提交
703

704 705 706 707
    def _update_xd_list_v1(self):
        """更新线段序列"""
        if len(self.bi_list) < 4:
            return
Z
zengbin93 已提交
708

Z
zengbin93 已提交
709
        self.xd_list = []
710 711 712 713 714 715
        if len(self.xd_list) == 0:
            for i in range(3):
                xd = dict(self.bi_list[i])
                xd['xd'] = xd.pop('bi')
                self.xd_list.append(xd)

Z
zengbin93 已提交
716
        right_bi = [x for x in self.bi_list if x['dt'] >= self.xd_list[-1]['dt']]
Z
zengbin93 已提交
717

Z
zengbin93 已提交
718
        xd_p = get_potential_xd(right_bi)
719 720 721 722 723 724 725 726 727 728
        for xp in xd_p:
            xd = dict(xp)
            xd['xd'] = xd.pop('bi')
            last_xd = self.xd_list[-1]
            if last_xd['fx_mark'] == xd['fx_mark']:
                if (last_xd['fx_mark'] == 'd' and last_xd['xd'] > xd['xd']) \
                        or (last_xd['fx_mark'] == 'g' and last_xd['xd'] < xd['xd']):
                    if self.verbose:
                        print("更新线段标记:from {} to {}".format(last_xd, xd))
                    self.xd_list[-1] = xd
Z
zengbin93 已提交
729
            else:
730 731 732
                if (last_xd['fx_mark'] == 'd' and last_xd['xd'] > xd['xd']) \
                        or (last_xd['fx_mark'] == 'g' and last_xd['xd'] < xd['xd']):
                    continue
Z
zengbin93 已提交
733

734 735 736 737 738 739
                bi_inside = [x for x in right_bi if last_xd['dt'] <= x['dt'] <= xd['dt']]
                if len(bi_inside) < 4:
                    if self.verbose:
                        print("{} - {} 之间笔标记数量少于4,跳过".format(last_xd['dt'], xd['dt']))
                    continue
                else:
Z
zengbin93 已提交
740
                    self.xd_list.append(xd)
741

742
    def _xd_after_process(self):
Z
zengbin93 已提交
743
        """线段标记后处理,使用标准特征序列判断线段标记是否成立"""
744
        if not len(self.xd_list) > 4:
Z
zengbin93 已提交
745 746
            return

747 748 749 750 751 752
        keep_xd_index = []
        for i in range(1, len(self.xd_list)-2):
            xd1, xd2, xd3, xd4 = self.xd_list[i-1: i+3]
            bi_seq1 = [x for x in self.bi_list if xd2['dt'] >= x['dt'] >= xd1['dt']]
            bi_seq2 = [x for x in self.bi_list if xd3['dt'] >= x['dt'] >= xd2['dt']]
            bi_seq3 = [x for x in self.bi_list if xd4['dt'] >= x['dt'] >= xd3['dt']]
Z
zengbin93 已提交
753 754 755
            if len(bi_seq1) == 0 or len(bi_seq2) == 0 or len(bi_seq3) == 0:
                continue

756 757 758
            if is_valid_xd(bi_seq1, bi_seq2, bi_seq3):
                keep_xd_index.append(i)

Z
zengbin93 已提交
759
        # 处理最近一个确定的线段标记
760 761 762
        bi_seq1 = [x for x in self.bi_list if self.xd_list[-2]['dt'] >= x['dt'] >= self.xd_list[-3]['dt']]
        bi_seq2 = [x for x in self.bi_list if self.xd_list[-1]['dt'] >= x['dt'] >= self.xd_list[-2]['dt']]
        bi_seq3 = [x for x in self.bi_list if x['dt'] >= self.xd_list[-1]['dt']]
Z
zengbin93 已提交
763 764
        if not (len(bi_seq1) == 0 or len(bi_seq2) == 0 or len(bi_seq3) == 0):
            if is_valid_xd(bi_seq1, bi_seq2, bi_seq3):
Z
zengbin93 已提交
765 766 767 768 769
                keep_xd_index.append(len(self.xd_list) - 2)

        # 处理最近一个未确定的线段标记
        if len(bi_seq3) >= 4:
            keep_xd_index.append(len(self.xd_list) - 1)
770 771 772 773 774

        new_xd_list = []
        for j in keep_xd_index:
            if not new_xd_list:
                new_xd_list.append(self.xd_list[j])
Z
zengbin93 已提交
775
            else:
776 777 778 779
                if new_xd_list[-1]['fx_mark'] == self.xd_list[j]['fx_mark']:
                    if (new_xd_list[-1]['fx_mark'] == 'd' and new_xd_list[-1]['xd'] > self.xd_list[j]['xd']) \
                            or (new_xd_list[-1]['fx_mark'] == 'g' and new_xd_list[-1]['xd'] < self.xd_list[j]['xd']):
                        new_xd_list[-1] = self.xd_list[j]
Z
zengbin93 已提交
780
                else:
781 782
                    new_xd_list.append(self.xd_list[j])
        self.xd_list = new_xd_list
783

784
        # 针对最近一个线段标记处理
Z
zengbin93 已提交
785 786 787 788
        if self.xd_list:
            if (self.xd_list[-1]['fx_mark'] == 'd' and self.bi_list[-1]['bi'] < self.xd_list[-1]['xd']) \
                    or (self.xd_list[-1]['fx_mark'] == 'g' and self.bi_list[-1]['bi'] > self.xd_list[-1]['xd']):
                self.xd_list.pop(-1)
789

790 791
    def _update_xd_list(self):
        self._update_xd_list_v1()
792
        self._xd_after_process()
793

Z
zengbin93 已提交
794 795 796 797 798 799 800 801 802 803 804
    def update(self, k):
        """更新分析结果

        :param k: dict
            单根K线对象,样例如下
            {'symbol': '000001.SH',
             'dt': Timestamp('2020-07-16 15:00:00'),
             'open': 3356.11,
             'close': 3210.1,
             'high': 3373.53,
             'low': 3209.76,
Z
zengbin93 已提交
805
             'vol': 486366915.0}
Z
zengbin93 已提交
806
        """
Z
zengbin93 已提交
807 808
        if self.verbose:
            print("=" * 100)
Z
zengbin93 已提交
809
            print("输入新K线:{}".format(k))
Z
zengbin93 已提交
810
        if not self.kline_raw or k['open'] != self.kline_raw[-1]['open']:
Z
zengbin93 已提交
811 812 813
            self.kline_raw.append(k)
        else:
            if self.verbose:
Z
zengbin93 已提交
814
                print("输入K线处于未完成状态,更新:replace {} with {}".format(self.kline_raw[-1], k))
Z
zengbin93 已提交
815 816
            self.kline_raw[-1] = k

817
        self._update_ta()
Z
zengbin93 已提交
818 819 820
        self._update_kline_new()
        self._update_fx_list()
        self._update_bi_list()
Z
zengbin93 已提交
821 822
        self._update_xd_list()

Z
zengbin93 已提交
823 824 825
        self.end_dt = self.kline_raw[-1]['dt']
        self.latest_price = self.kline_raw[-1]['close']

Z
zengbin93 已提交
826 827 828 829 830 831
        if len(self.xd_list) > self.max_xd_len:
            last_dt = self.xd_list[-self.max_xd_len:][0]['dt']
            self.kline_raw = [x for x in self.kline_raw if x['dt'] > last_dt]
            self.kline_new = [x for x in self.kline_new if x['dt'] > last_dt]
            self.ma = [x for x in self.ma if x['dt'] > last_dt]
            self.macd = [x for x in self.macd if x['dt'] > last_dt]
Z
zengbin93 已提交
832 833 834 835
            self.fx_list = [x for x in self.fx_list if x['dt'] > last_dt]
            self.bi_list = [x for x in self.bi_list if x['dt'] > last_dt]
            self.xd_list = [x for x in self.xd_list if x['dt'] > last_dt]

Z
zengbin93 已提交
836 837
        if self.verbose:
            print("更新结束\n\n")
Z
zengbin93 已提交
838

Z
zengbin93 已提交
839
    def to_df(self, ma_params=(5, 20), use_macd=False, max_count=1000, mode="raw"):
Z
zengbin93 已提交
840 841 842 843 844 845
        """整理成 df 输出

        :param ma_params: tuple of int
            均线系统参数
        :param use_macd: bool
        :param max_count: int
Z
zengbin93 已提交
846 847
        :param mode: str
            使用K线类型, raw = 原始K线,new = 去除包含关系的K线
Z
zengbin93 已提交
848 849
        :return: pd.DataFrame
        """
Z
zengbin93 已提交
850 851 852 853 854 855 856
        if mode == "raw":
            bars = self.kline_raw[-max_count:]
        elif mode == "new":
            bars = self.kline_raw[-max_count:]
        else:
            raise ValueError

Z
zengbin93 已提交
857 858 859
        fx_list = {x["dt"]: {"fx_mark": x["fx_mark"], "fx": x['fx']} for x in self.fx_list[-(max_count // 2):]}
        bi_list = {x["dt"]: {"fx_mark": x["fx_mark"], "bi": x['bi']} for x in self.bi_list[-(max_count // 4):]}
        xd_list = {x["dt"]: {"fx_mark": x["fx_mark"], "xd": x['xd']} for x in self.xd_list[-(max_count // 8):]}
Z
zengbin93 已提交
860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877
        results = []
        for k in bars:
            k['fx_mark'], k['fx'], k['bi'], k['xd'] = "o", None, None, None
            fx_ = fx_list.get(k['dt'], None)
            bi_ = bi_list.get(k['dt'], None)
            xd_ = xd_list.get(k['dt'], None)
            if fx_:
                k['fx_mark'] = fx_["fx_mark"]
                k['fx'] = fx_["fx"]

            if bi_:
                k['bi'] = bi_["bi"]

            if xd_:
                k['xd'] = xd_["xd"]

            results.append(k)
        df = pd.DataFrame(results)
Z
zengbin93 已提交
878 879
        for p in ma_params:
            df.loc[:, "ma{}".format(p)] = ta.SMA(df.close.values, p)
Z
zengbin93 已提交
880
        if use_macd:
Z
zengbin93 已提交
881 882 883 884
            diff, dea, macd = ta.MACD(df.close.values)
            df.loc[:, "diff"] = diff
            df.loc[:, "dea"] = diff
            df.loc[:, "macd"] = diff
Z
zengbin93 已提交
885 886 887 888 889 890 891 892 893 894 895 896 897 898 899
        return df

    def to_image(self, file_image, mav=(5, 20, 120, 250), max_k_count=1000, dpi=50):
        """保存成图片

        :param file_image: str
            图片名称,支持 jpg/png/svg 格式,注意后缀
        :param mav: tuple of int
            均线系统参数
        :param max_k_count: int
            设定最大K线数量,这个值越大,生成的图片越长
        :param dpi: int
            图片分辨率
        :return:
        """
Z
zengbin93 已提交
900
        ka_to_image(self, file_image=file_image, mav=mav, max_k_count=max_k_count, dpi=dpi)
Z
zengbin93 已提交
901

902
    def is_bei_chi(self, zs1, zs2, mode="bi", adjust=0.9, last_index: int = None):
903
        """判断 zs1 对 zs2 是否有背驰
Z
zengbin93 已提交
904

905 906 907 908 909 910 911 912 913 914 915 916 917 918 919
        注意:力度的比较,并没有要求两段走势方向一致;但是如果两段走势之间存在包含关系,这样的力度比较是没有意义的。

        :param zs1: dict
            用于比较的走势,通常是最近的走势,示例如下:
            zs1 = {"start_dt": "2020-02-20 11:30:00", "end_dt": "2020-02-20 14:30:00", "direction": "up"}
        :param zs2: dict
            被比较的走势,通常是较前的走势,示例如下:
            zs2 = {"start_dt": "2020-02-21 11:30:00", "end_dt": "2020-02-21 14:30:00", "direction": "down"}
        :param mode: str
            default `bi`, optional value [`xd`, `bi`]
            xd  判断两个线段之间是否存在背驰
            bi  判断两笔之间是否存在背驰
        :param adjust: float
            调整 zs2 的力度,建议设置范围在 0.6 ~ 1.0 之间,默认设置为 0.9;
            其作用是确保 zs1 相比于 zs2 的力度足够小。
920 921
        :param last_index: int
            在比较最后一个走势的时候,可以设置这个参数来提升速度,相当于只对 last_index 后面的K线进行力度比较
922 923 924 925 926 927 928 929
        :return: bool
        """
        assert zs1["start_dt"] > zs2["end_dt"], "zs1 必须是最近的走势,用于比较;zs2 必须是较前的走势,被比较。"
        assert zs1["start_dt"] < zs1["end_dt"], "走势的时间区间定义错误,必须满足 start_dt < end_dt"
        assert zs2["start_dt"] < zs2["end_dt"], "走势的时间区间定义错误,必须满足 start_dt < end_dt"

        min_dt = min(zs1["start_dt"], zs2["start_dt"])
        max_dt = max(zs1["end_dt"], zs2["end_dt"])
930 931 932 933 934
        if last_index:
            macd = self.macd[-last_index:]
        else:
            macd = self.macd
        macd_ = [x for x in macd if x['dt'] >= min_dt]
Z
zengbin93 已提交
935
        macd_ = [x for x in macd_ if max_dt >= x['dt']]
936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968
        k1 = [x for x in macd_ if zs1["end_dt"] >= x['dt'] >= zs1["start_dt"]]
        k2 = [x for x in macd_ if zs2["end_dt"] >= x['dt'] >= zs2["start_dt"]]

        bc = False
        if mode == 'bi':
            macd_sum1 = sum([abs(x['macd']) for x in k1])
            macd_sum2 = sum([abs(x['macd']) for x in k2])
            # print("bi: ", macd_sum1, macd_sum2)
            if macd_sum1 < macd_sum2 * adjust:
                bc = True

        elif mode == 'xd':
            assert zs1['direction'] in ['down', 'up'], "走势的 direction 定义错误,可取值为 up 或 down"
            assert zs2['direction'] in ['down', 'up'], "走势的 direction 定义错误,可取值为 up 或 down"

            if zs1['direction'] == "down":
                macd_sum1 = sum([abs(x['macd']) for x in k1 if x['macd'] < 0])
            else:
                macd_sum1 = sum([abs(x['macd']) for x in k1 if x['macd'] > 0])

            if zs2['direction'] == "down":
                macd_sum2 = sum([abs(x['macd']) for x in k2 if x['macd'] < 0])
            else:
                macd_sum2 = sum([abs(x['macd']) for x in k2 if x['macd'] > 0])

            # print("xd: ", macd_sum1, macd_sum2)
            if macd_sum1 < macd_sum2 * adjust:
                bc = True

        else:
            raise ValueError("mode value error")

        return bc
Z
zengbin93 已提交
969

Z
zengbin93 已提交
970
    def get_sub_section(self, start_dt, end_dt, mode="bi", is_last=True):
Z
zengbin93 已提交
971 972 973 974 975 976 977
        """获取子区间

        :param start_dt: datetime
            子区间开始时间
        :param end_dt: datetime
            子区间结束时间
        :param mode: str
Z
zengbin93 已提交
978
            需要获取的子区间对象类型,可取值 ['kn', 'fx', 'bi', 'xd']
Z
zengbin93 已提交
979 980
        :param is_last: bool
            是否是最近一段子区间
Z
zengbin93 已提交
981 982
        :return: list of dict
        """
Z
zengbin93 已提交
983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006
        if mode == "kn":
            if is_last:
                points = self.kline_new[-200:]
            else:
                points = self.kline_new
        elif mode == "fx":
            if is_last:
                points = self.fx_list[-100:]
            else:
                points = self.fx_list
        elif mode == "bi":
            if is_last:
                points = self.bi_list[-50:]
            else:
                points = self.bi_list
        elif mode == "xd":
            if is_last:
                points = self.xd_list[-30:]
            else:
                points = self.xd_list
        else:
            raise ValueError

        return [x for x in points if end_dt >= x['dt'] >= start_dt]
Z
zengbin93 已提交
1007

Z
zengbin93 已提交
1008
    def calculate_macd_power(self, start_dt, end_dt, mode='bi', direction="up"):
1009
        """用 MACD 计算走势段(start_dt ~ end_dt)的力度
Z
zengbin93 已提交
1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021

        :param start_dt: datetime
            走势开始时间
        :param end_dt: datetime
            走势结束时间
        :param mode: str
            分段走势类型,默认值为 bi,可选值 ['bi', 'xd'],分别表示笔分段走势和线段分段走势
        :param direction: str
            线段分段走势计算力度需要指明方向,可选值 ['up', 'down']
        :return: float
            走势力度
        """
Z
zengbin93 已提交
1022
        fd_macd = [x for x in self.macd if end_dt >= x['dt'] >= start_dt]
Z
zengbin93 已提交
1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036

        if mode == 'bi':
            power = sum([abs(x['macd']) for x in fd_macd])
        elif mode == 'xd':
            if direction == 'up':
                power = sum([abs(x['macd']) for x in fd_macd if x['macd'] > 0])
            elif direction == 'down':
                power = sum([abs(x['macd']) for x in fd_macd if x['macd'] < 0])
            else:
                raise ValueError
        else:
            raise ValueError
        return power

Z
zengbin93 已提交
1037
    def calculate_vol_power(self, start_dt, end_dt):
1038 1039 1040 1041 1042 1043 1044 1045 1046
        """用 VOL 计算走势段(start_dt ~ end_dt)的力度

        :param start_dt: datetime
            走势开始时间
        :param end_dt: datetime
            走势结束时间
        :return: float
            走势力度
        """
Z
zengbin93 已提交
1047
        fd_vol = [x for x in self.kline_raw if end_dt >= x['dt'] >= start_dt]
1048 1049
        power = sum([x['vol'] for x in fd_vol])
        return int(power)
Z
zengbin93 已提交
1050

Z
zengbin93 已提交
1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121
    def get_latest_fd(self, n=6, mode="bi"):
        """获取最近的走势分段

        fd 为 dict 对象,表示一段走势,可以是笔、线段,样例如下:

        fd = {
            "start_dt": "",
            "end_dt": "",
            "power": 0,         # 力度
            "direction": "up",
            "high": 0,
            "low": 0,
            "mode": "bi"
        }

        :param n:
        :param mode:
        :return: list of dict
        """
        if mode == 'bi':
            points = self.bi_list[-(n + 1):]
        elif mode == 'xd':
            points = self.xd_list[-(n + 1):]
        else:
            raise ValueError

        res = []
        for i in range(len(points)-1):
            p1 = points[i]
            p2 = points[i+1]
            direction = "up" if p1[mode] < p2[mode] else "down"
            power = self.calculate_macd_power(start_dt=p1['dt'], end_dt=p2['dt'], mode=mode, direction=direction)
            res.append({
                "start_dt": p1['dt'],
                "end_dt": p2['dt'],
                "power": power,
                "direction": direction,
                "high": max(p1[mode], p2[mode]),
                "low": min(p1[mode], p2[mode]),
                "mode": mode
            })
        return res

    def get_last_fd(self, mode='bi'):
        """获取最后一个分段走势

        :param mode: str
            可选值 ['bi', 'xd'],默认值 'bi'
        :return:
        """
        if mode == 'bi':
            p1 = self.bi_list[-1]
            points = [x for x in self.fx_list[-60:] if x['dt'] >= p1['dt']]
            if len(points) < 2:
                return None

            if p1['fx_mark'] == 'd':
                direction = "up"
                max_fx = max([x['fx'] for x in points if x['fx_mark'] == 'g'])
                p2 = [x for x in points if x['fx'] == max_fx][0]
            elif p1['fx_mark'] == 'g':
                direction = "down"
                min_fx = min([x['fx'] for x in points if x['fx_mark'] == 'd'])
                p2 = [x for x in points if x['fx'] == min_fx][0]
            else:
                raise ValueError

            p2 = dict(p2)
            p2['bi'] = p2.pop('fx')

        elif mode == 'xd':
Z
zengbin93 已提交
1122 1123 1124
            if not self.xd_list:
                return None

Z
zengbin93 已提交
1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155
            p1 = self.xd_list[-1]
            points = [x for x in self.bi_list[-60:] if x['dt'] >= p1['dt']]
            if len(points) < 4:
                return None

            if p1['fx_mark'] == 'd':
                direction = "up"
                max_fx = max([x['bi'] for x in points if x['fx_mark'] == 'g'])
                p2 = [x for x in points if x['bi'] == max_fx][0]
            elif p1['fx_mark'] == 'g':
                direction = "down"
                min_fx = min([x['bi'] for x in points if x['fx_mark'] == 'd'])
                p2 = [x for x in points if x['bi'] == min_fx][0]
            else:
                raise ValueError

            p2 = dict(p2)
            p2['xd'] = p2.pop('bi')
        else:
            raise ValueError

        power = self.calculate_macd_power(start_dt=p1['dt'], end_dt=p2['dt'], mode=mode, direction=direction)
        return {
            "start_dt": p1['dt'],
            "end_dt": p2['dt'],
            "power": power,
            "direction": direction,
            "high": max(p1[mode], p2[mode]),
            "low": min(p1[mode], p2[mode]),
            "mode": mode
        }