Source code for dsdtools

from __future__ import print_function
from .audio_classes import Track, Source, Target
from . import evaluate
from os import path as op
from six.moves import map
import multiprocessing
import soundfile as sf
import collections
import numpy as np
import functools
import signal
import yaml
import glob
import tqdm
import os
import dsdtools


[docs]class DB(object): """ The dsdtools DB Object Parameters ---------- root_dir : str, optional dsdtools Root path. If set to `None` it will be read from the `DSD_PATH` environment variable subsets : str or list, optional select a _dsdtools_ subset `Dev` or `Test` (defaults to both) setup_file : str, optional _dsdtools_ Setup file in yaml format. Default is provided `dsd100.yaml` evaluation : str, {None, 'bss_eval', 'mir_eval'} Setup evaluation module and starts matlab if bsseval is enabled valid_ids : list[int] or int, optional select single or multiple _dsdtools_ items by ID that will be used for validation data (ie not included in the `Dev` set) Attributes ---------- setup_file : str path to yaml file. default: `setup.yaml` root_dir : str dsdtools Root path. Default is `DSD_PATH` env evaluation : bool Setup evaluation module mixtures_dir : str path to Mixture directory sources_dir : str path to Sources directory sources_names : list[str] list of names of sources targets_names : list[str] list of names of targets evaluator : BSSeval evaluator used for evaluation of estimates setup : Dict loaded yaml configuration Methods ------- load_dsd_tracks() Iterates through the dsdtools folder structure and returns ``Track`` objects test(user_function) Test the dsdtools processing evaluate() Run the evaluation run(user_function=None, estimates_dir=None, evaluate=False) Run the dsdtools processing, saving the estimates and optionally evaluate them """ def __init__( self, root_dir=None, setup_file=None, evaluation=None, valid_ids=None, ): if root_dir is None: if "DSD_PATH" in os.environ: self.root_dir = os.environ["DSD_PATH"] else: raise RuntimeError("Variable `DSD_PATH` has not been set.") else: self.root_dir = root_dir if setup_file is not None: setup_path = op.join(self.root_dir, setup_file) else: setup_path = os.path.join( dsdtools.__path__[0], 'configs', 'dsd100.yaml' ) with open(setup_path, 'r') as f: self.setup = yaml.load(f) self.mixtures_dir = op.join( self.root_dir, "Mixtures" ) self.sources_dir = op.join( self.root_dir, "Sources" ) if valid_ids is not None: if not isinstance(valid_ids, collections.Sequence): valid_ids = [valid_ids] self.valid_ids = valid_ids self.sources_names = list(self.setup['sources'].keys()) self.targets_names = list(self.setup['targets'].keys()) if evaluation is not None: self.evaluator = evaluate.BSSeval(evaluation)
[docs] def load_dsd_tracks(self, subsets=None, ids=None): """Parses the dsdtools folder structure, returns list of `Track` objects Parameters ========== subsets : list[str], optional select a _dsdtools_ subset `Dev` or `Test`. Defaults to both ids : list[int] or int, optional select single or multiple _dsdtools_ items by ID Returns ------- list[Track] return a list of ``Track`` Objects """ # parse all the mixtures if ids is not None: if not isinstance(ids, collections.Sequence): ids = [ids] if subsets is not None: if isinstance(subsets, str): subsets = [subsets] else: subsets = subsets if 'Valid' in subsets and 'Dev' in subsets: raise ValueError( "Cannot load Valid and Dev at the same time" ) else: subsets = ['Dev', 'Test'] tracks = [] if op.isdir(self.mixtures_dir): for subset in subsets: # For validation use Dev set and filter by ids later if subset == 'Valid': subset_folder = op.join(self.mixtures_dir, 'Dev') else: subset_folder = op.join(self.mixtures_dir, subset) for _, track_folders, _ in os.walk(subset_folder): for track_filename in sorted(track_folders): # create new dsd Track track = Track( filename=track_filename, path=op.join( op.join(subset_folder, track_filename), self.setup['mix'] ), subset=subset ) # add sources to track sources = {} for src, rel_path in list( self.setup['sources'].items() ): # create source object abs_path = op.join( self.sources_dir, subset, track_filename, rel_path ) if os.path.exists(abs_path): sources[src] = Source( name=src, path=abs_path ) track.sources = sources # add targets to track targets = collections.OrderedDict() for name, target_srcs in list( self.setup['targets'].items() ): # add a list of target sources target_sources = [] for source, gain in list(target_srcs.items()): if source in list(track.sources.keys()): # add gain to source tracks track.sources[source].gain = float(gain) # add tracks to components target_sources.append(sources[source]) # add sources to target if target_sources: targets[name] = Target(sources=target_sources) # add targets to track track.targets = targets # add track to list of tracks tracks.append(track) # Filter tracks by valid_ids if self.valid_ids is not None: if subset == 'Dev': tracks = [t for t in tracks if t.id not in self.valid_ids] if subset == 'Valid': tracks = [t for t in tracks if t.id in self.valid_ids] if ids is not None: return [t for t in tracks if t.id in ids] else: return tracks
def _save_estimates(self, user_estimates, track, estimates_dir): track_estimate_dir = op.join( estimates_dir, track.subset, track.filename ) if not os.path.exists(track_estimate_dir): os.makedirs(track_estimate_dir) # write out tracks to disk for target, estimate in list(user_estimates.items()): target_path = op.join(track_estimate_dir, target + '.wav') sf.write(target_path, estimate, track.rate) pass def _evaluate_estimates(self, user_estimates, track): audio_estimates = [] audio_reference = [] # make sure to always build the list in the same order # therefore track.targets is an OrderedDict labels_references = [] # save the list of targets to be evaluated for target in list(track.targets.keys()): try: # try to fetch the audio from the user_results of a given key estimate = user_estimates[target] # append this target name to the list of labels labels_references.append(target) # add the audio to the list of estimates audio_estimates.append(estimate) # add the audio to the list of references audio_reference.append(track.targets[target].audio) except KeyError: pass if audio_estimates and audio_reference: audio_estimates = np.array(audio_estimates) audio_reference = np.array(audio_reference) if audio_estimates.shape == audio_reference.shape: self.evaluator.evaluate( audio_estimates, audio_reference, track.rate )
[docs] def test(self, user_function): """Test the dsdtools processing Parameters ---------- user_function : callable, optional function which separates the mixture into estimates. If no function is provided (default in `None`) estimates are loaded from disk when `evaluate is True`. Raises ------ TypeError If the provided function handle is not callable. ValueError If the output is not compliant to the bsseval methods See Also -------- run : Process the dsdtools """ if not hasattr(user_function, '__call__'): raise TypeError("Please provide a function.") test_track = Track(filename="test") signal = np.random.random((66000, 2)) test_track.audio = signal test_track.rate = 44100 user_results = user_function(test_track) if isinstance(user_results, dict): for target, audio in list(user_results.items()): if target not in self.targets_names: raise ValueError("Target '%s' not supported!" % target) d = audio.dtype if not np.issubdtype(d, float): raise ValueError( "Estimate is not of type numpy.float_" ) if audio.shape != signal.shape: raise ValueError( "Shape of estimate does not match input shape" ) else: raise ValueError("output needs to be a dict") return True
[docs] def evaluate( self, user_function=None, estimates_dir=None, *args, **kwargs ): """Run the dsdtools evaluation shortcut to ``run( user_function=None, estimates_dir=estimates_dir, evaluate=True )`` """ return self.run( user_function=user_function, estimates_dir=estimates_dir, evaluate=True, *args, **kwargs )
def _process_function(self, track, user_function, estimates_dir, evaluate): # load estimates from disk instead of processing if user_function is None: track_estimate_dir = op.join( estimates_dir, track.subset, track.filename ) user_results = {} for target_path in glob.glob(track_estimate_dir + '/*.wav'): target_name = op.splitext( os.path.basename(target_path) )[0] try: target_audio, rate = sf.read( target_path, always_2d=True ) user_results[target_name] = target_audio except RuntimeError: pass else: # call the user provided function user_results = user_function(track) if estimates_dir and not evaluate and user_function is not None: self._save_estimates(user_results, track, estimates_dir) if evaluate: self._evaluate_estimates(user_results, track)
[docs] def run( self, user_function=None, estimates_dir=None, evaluate=False, subsets=None, ids=None, parallel=False, cpus=4 ): """Run the dsdtools processing Parameters ---------- user_function : callable, optional function which separates the mixture into estimates. If no function is provided (default in `None`) estimates are loaded from disk when `evaluate is True`. estimates_dir : str, optional path to the user provided estimates. Directory will be created if it does not exist. Default is `none` which means that the results are not saved. evaluate : bool, optional evaluate the estimates by using. Default is False subsets : list[str], optional select a _dsdtools_ subset `Dev` or `Test`. Defaults to both ids : list[int] or int, optional select single or multiple _dsdtools_ items by ID parallel: bool, optional activate multiprocessing cpus: int, optional set number of cores if `parallel` mode is active, defaults to 4 Raises ------ RuntimeError If the provided function handle is not callable. See Also -------- test : Test the user provided function """ if user_function is None and estimates_dir and evaluate is None: raise RuntimeError("Provide a function or use evaluate feature!") try: ids = int(os.environ['DSD_ID']) except KeyError: pass # list of tracks to be processed tracks = self.load_dsd_tracks(subsets=subsets, ids=ids) success = False if parallel: pool = multiprocessing.Pool(cpus, initializer=init_worker) success = list( tqdm.tqdm( pool.imap_unordered( func=functools.partial( process_function_alias, self, user_function=user_function, estimates_dir=estimates_dir, evaluate=evaluate ), iterable=tracks, chunksize=1 ), total=len(tracks) ) ) pool.close() pool.join() else: success = list( tqdm.tqdm( map( lambda x: self._process_function( x, user_function, estimates_dir, evaluate ), tracks ), total=len(tracks) ) ) return success
[docs]def process_function_alias(obj, *args, **kwargs): return obj._process_function(*args, **kwargs)
[docs]def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)