提交 a41b8f35 编写于 作者: F Faylixe

️ fix heavy import on cli parsing

上级 0428a0bd
......@@ -5,10 +5,12 @@
Python oneliner script usage.
USAGE: python -m spleeter {train,evaluate,separate} ...
"""
# NOTE: disable TF logging before import.
from .utils.logging import configure_logger, logger
Notes:
All critical import involving TF, numpy or Pandas are deported to
command function scope to avoid heavy import on CLI evaluation,
leading to large bootstraping time.
"""
import json
......@@ -17,24 +19,14 @@ from itertools import product
from glob import glob
from os.path import join
from pathlib import Path
from typing import Any, Container, Dict, List
from typing import Container, Dict, List
from . import SpleeterError
from .audio import Codec
from .audio.adapter import AudioAdapter
from .options import *
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .separator import Separator
from .utils.configuration import load_configuration
from .utils.logging import configure_logger, logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import pandas as pd
import tensorflow as tf
from typer import Exit, Typer
# pylint: enable=import-error
......@@ -51,6 +43,14 @@ def train(
"""
Train a source separation model
"""
from .audio.adapter import AudioAdapter
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .utils.configuration import load_configuration
import tensorflow as tf
configure_logger(verbose)
audio_adapter = AudioAdapter.get(adapter)
audio_path = str(data)
......@@ -104,6 +104,9 @@ def separate(
"""
Separate audio file(s)
"""
from .audio.adapter import AudioAdapter
from .separator import Separator
configure_logger(verbose)
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
separator: Separator = Separator(
......@@ -144,6 +147,9 @@ def _compile_metrics(metrics_output_directory) -> Dict:
Dict:
Compiled metrics as dict.
"""
import pandas as pd
import numpy as np
songs = glob(join(metrics_output_directory, 'test/*.json'))
index = pd.MultiIndex.from_tuples(
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
......@@ -178,6 +184,8 @@ def evaluate(
"""
Evaluate a model on the musDB test dataset
"""
import numpy as np
configure_logger(verbose)
try:
import musdb
......
......@@ -12,11 +12,6 @@
from enum import Enum
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
......@@ -42,6 +37,12 @@ class STFTBackend(str, Enum):
@classmethod
def resolve(cls: type, backend: str) -> str:
# NOTE: import is resolved here to avoid performance issues on command
# evaluation.
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
if backend not in cls.__members__.values():
raise ValueError(f'Unsupported backend {backend}')
if backend == cls.AUTO:
......
......@@ -6,11 +6,10 @@
from tempfile import gettempdir
from os.path import join
from .separator import STFTBackend
from .audio import Codec
from .audio import Codec, STFTBackend
from typer import Argument, Option
from typer.models import ArgumentInfo, OptionInfo
from typer import Option
from typer.models import OptionInfo
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册