#!/usr/bin/env python3
# -*- coding: utf-8 -*-
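"""Generate a header file for reducing binary size by stripping unused oprs
in a particular network; by default the output is written to
utils/bin_reduce.h under the megvii3 root.

Inputs may be json files produced by megbrain.serialize_comp_graph_to_file()
or trace files produced by the midout library.

Usage (illustrative; the input file names are placeholders)::

    ./gen_header_for_bin_reduce.py net.json midout.trace -o bin_reduce.h
"""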

import sys
import re

if sys.version_info[0] != 3 or sys.version_info[1] < 5:
    print('This script requires Python 3.5 or later')
    sys.exit(1)

import argparse
import json
import os
import subprocess
import tempfile
from pathlib import Path

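# Files beginning with this magic string are treated as midout trace files;
# all other inputs are parsed as network-info json (see main()).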
MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'

class HeaderGen:
    _dtypes = None
    _oprs = None
    _fout = None
    _elemwise_modes = None
    _has_netinfo = False
    _midout_files = None
    _graph_hashes = None

    _file_without_hash = False

    def __init__(self):
        self._dtypes = set()
        self._oprs = set()
        self._elemwise_modes = set()
        self._graph_hashes = set()
        self._midout_files = []

    _megvii3_root_cache = None
    @classmethod
    def get_megvii3_root(cls):
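        """Locate the megvii3 root by walking up from this script's directory
        until a WORKSPACE file is found; the result is cached."""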
        if cls._megvii3_root_cache is not None:
            return cls._megvii3_root_cache
        wd = Path(__file__).resolve().parent
        while wd.parent != wd:
            workspace_file = wd / 'WORKSPACE'
            if workspace_file.is_file():
                cls._megvii3_root_cache = str(wd)
                return cls._megvii3_root_cache
            wd = wd.parent
        raise RuntimeError(
            'This script must be run inside a megvii3 workspace '
            '(no WORKSPACE file found in any parent directory)')

    def extend_netinfo(self, data):
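        """Merge one network-info dict (parsed from a json file produced by
        megbrain.serialize_comp_graph_to_file()) into the collected dtypes,
        opr types, elemwise modes and graph hashes."""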
        self._has_netinfo = True
        if 'hash' not in data:
            self._file_without_hash = True
        else:
            self._graph_hashes.add(str(data['hash']))
        for i in data['dtypes']:
            self._dtypes.add(i)
        for i in data['opr_types']:
            self._oprs.add(i)
        for i in data['elemwise_modes']:
            self._elemwise_modes.add(i)

    def extend_midout(self, fname):
        self._midout_files.append(fname)

    def generate(self, fout):
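        """Write all collected definitions to ``fout`` as C preprocessor
        macros."""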
        self._fout = fout
        self._write_def('MGB_BINREDUCE_VERSION', '20190219')
        if self._has_netinfo:
            self._write_dtype()
            self._write_elemwise_modes()
            self._write_oprs()
            self._write_hash()
        self._write_midout()
        del self._fout

    def strip_opr_name_with_version(self, name):
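        """Strip a trailing version suffix from an opr type name: a
        (hypothetical) name ``FooV2`` becomes ``Foo``; names without such a
        suffix are returned unchanged."""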
        pos = len(name)
        t = re.search(r'V\d+$', name)
        if t:
            pos = t.start()
        return name[:pos]

    def _write_oprs(self):
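        """Emit the MGB_OPR_REGISTRY_CALLER_SPECIALIZE macro: forward
        declarations of the used oprs plus OprRegistryCaller specializations
        for them (the primary template is left empty), presumably so that
        only the used oprs are registered for serialization."""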
        defs = ['}', 'namespace opr {']
        already_declare = set()
        already_instance = set()
        for i in self._oprs:
            i = self.strip_opr_name_with_version(i)
            if i in already_declare:
                continue
            else:
                already_declare.add(i)

            defs.append('class {};'.format(i))
        defs.append('}')
        defs.append('namespace serialization {')
        defs.append("""
            template<class Opr, class Callee>
            struct OprRegistryCaller {
            }; """)
        for i in sorted(self._oprs):
            i = self.strip_opr_name_with_version(i)
            if i in already_instance:
                continue
            else:
                already_instance.add(i)

            defs.append("""
                template<class Callee>
                struct OprRegistryCaller<opr::{}, Callee>: public
                    OprRegistryCallerDefaultImpl<Callee> {{
                }}; """.format(i))
        self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)

    def _write_elemwise_modes(self):
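        """Emit _MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_* macros for every elemwise
        mode known to gen_param_defs.py; the macro body is non-empty only for
        the modes actually used by the network."""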
        with tempfile.NamedTemporaryFile() as ftmp:
            fpath = os.path.realpath(ftmp.name)
            subprocess.check_call(
                ['./brain/megbrain/dnn/scripts/gen_param_defs.py',
                 '--write-enum-items', 'Elemwise:Mode',
                 './brain/megbrain/dnn/scripts/opr_param_defs.py',
                 fpath],
                cwd=self.get_megvii3_root()
            )

            with open(fpath) as fin:
                mode_list = [i.strip() for i in fin]

        for i in mode_list:
            if i in self._elemwise_modes:
                content = '_cb({})'.format(i)
            else:
                content = ''
            self._write_def(
                '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content)
        self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
                        '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')

    def _write_dtype(self):
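        """Disable float16 support if no Float16 dtype is used by the
        network; see the comment below on the FLOT16/FLOAT16 macro names."""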
        if 'Float16' not in self._dtypes:
            # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
            # support in the past; however, `FLOT16' is a typo. We plan to
            # rename MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
            # To avoid issues during the transition, both macros (`FLOT16'
            # and `FLOAT16') are defined here.
            #
            # Once the transition is complete and legacy MegBrain/MegDNN is
            # no longer in use, the `FLOT16' definition can be safely
            # deleted.
            self._write_def('MEGDNN_DISABLE_FLOT16', 1)
            self._write_def('MEGDNN_DISABLE_FLOAT16', 1)

    def _write_hash(self):
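        """Emit MGB_BINREDUCE_GRAPH_HASHES from the collected graph hashes,
        or warn if any input json lacks a hash."""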
        if self._file_without_hash:
            print('WARNING: some network info contains no graph hash; using '
                  'json files generated by MegBrain >= 7.28.0 is recommended')
        else:
            defs = 'ULL,'.join(self._graph_hashes) + 'ULL'
            self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)

    def _write_def(self, name, val):
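        """Write a single #define to the output; list values are joined and
        embedded newlines are escaped with backslash continuations."""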
        if isinstance(val, list):
            val = '\n'.join(val)
        val = str(val).strip().replace('\n', ' \\\n')
        self._fout.write('#define {} {}\n'.format(name, val))

    def _write_midout(self):
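        """Append the header generated by midout's gen_header.py from the
        collected trace files, if any."""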
        if not self._midout_files:
            return

        gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout',
                           'gen_header.py')
        cvt = subprocess.run(
            [gen] + self._midout_files,
            stdout=subprocess.PIPE, check=True,
        ).stdout.decode('utf-8')
        self._fout.write('// midout \n')
        self._fout.write(cvt)

def main():
    parser = argparse.ArgumentParser(
        description='generate a header file for reducing binary size by '
        'stripping unused oprs in a particular network; by default the '
        'output is written to bin_reduce.h',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'inputs', nargs='+',
        help='input files that describe specific traits of the network; '
        'each can be one of the following: '
        '1. a json file generated by '
        'megbrain.serialize_comp_graph_to_file() in python; '
        '2. a trace file generated by the midout library')
    parser.add_argument('-o', '--output', help='output file',
                        default=os.path.join(HeaderGen.get_megvii3_root(),
                                             'utils', 'bin_reduce.h'))
    args = parser.parse_args()

    gen = HeaderGen()
    for i in args.inputs:
        print('==== processing {}'.format(i))
        with open(i) as fin:
            if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
                gen.extend_midout(i)
            else:
                fin.seek(0)
                gen.extend_netinfo(json.loads(fin.read()))

    with open(args.output, 'w') as fout:
        gen.generate(fout)

if __name__ == '__main__':
    main()