#!/usr/bin/env python3
"""Separate audio files into channels with a trained ChimeraNet model.

For every input file the script computes a (mel-)magnitude spectrogram,
runs the model window by window, and writes one audio file per channel for
each enabled inference (mask head and/or embedding head). Optionally it also
saves a figure of the masks and masked spectrograms.
"""
import os
import sys
import glob
import importlib
import importlib.util  # ``import importlib`` alone does not load ``util``
from argparse import ArgumentParser
from itertools import permutations

import numpy as np
import librosa


def main():
    """Parse arguments, load the model once and separate every input file."""
    args = parse_args()
    input_paths = []
    for pattern in args.input_audio:
        matched = glob.glob(pattern)
        if not matched:
            print('warning: no file matches {}'.format(pattern))
        input_paths.extend(matched)
    if not importlib.util.find_spec('chimeranet'):
        print('ChimeraNet is not installed, import from source.')
        sys.path.append(os.path.join(os.path.split(__file__)[0], '..'))
    from chimeranet.models import probe_model_shape, load_model
    args.n_frames, args.n_mels, args.n_channels, args.d_embedding \
        = probe_model_shape(args.model_path)
    if len(args.channel_name) < args.n_channels:
        raise ValueError(
            '{} channel name(s) given, but the model has {} channels.'.format(
                len(args.channel_name), args.n_channels
            )
        )
    args.model = load_model(args.model_path)
    if 0 < args.n_mels < args.n_fft // 2 + 1:
        # Keyword arguments: librosa>=0.10 makes these keyword-only.
        args.mel_basis = librosa.filters.mel(
            sr=args.sr, n_fft=args.n_fft, n_mels=args.n_mels, norm=None
        )
    for input_path in input_paths:
        part(input_path, **args.__dict__)


def _ensure_directory(path):
    """Make sure the parent directory of *path* exists.

    Returns True when the parent is (now) a directory, False otherwise (a
    warning is printed when a non-directory already occupies that path).
    A bare file name refers to the current directory and needs no creation;
    ``os.makedirs('')`` would raise.
    """
    directory = os.path.dirname(path)
    if not directory:
        return True
    if not os.path.exists(directory):
        os.makedirs(directory)
    elif not os.path.isdir(directory):
        print('warning: {} will not be created.'.format(directory))  # TODO
    return os.path.isdir(directory)


def _best_permutation(masks, reference, weight):
    """Reorder the channels of *masks* to best match *reference*.

    Every channel permutation of *masks* is tried; the one with the smallest
    *weight*-weighted absolute difference to *reference* is returned as an
    array with the same shape as *masks*.
    """
    return np.array(min(
        permutations(masks),
        key=lambda t: np.sum(np.abs(np.array(t) - reference) * weight)
    ))


def _predict_masks(spec, **kwargs):
    """Run the model over *spec* and merge the per-window predictions.

    With ``batch_size > 0`` the spectrogram is processed in chunks of
    ``batch_size * n_frames`` frames, otherwise in one pass; chunks shorter
    than ``n_frames`` are zero-padded.

    Returns:
        ``(mask_embd, mask_mask)``: masks derived from the embedding output
        (via ``from_embedding``) and from the mask output, each trimmed to
        ``spec.shape[1]`` frames. Within each chunk, embedding masks are
        channel-aligned to the mask-output masks.
    """
    from chimeranet import from_embedding, from_mask
    from chimeranet import split_window
    from chimeranet import merge_windows_mean, merge_windows_most_common

    n_frames = kwargs['n_frames']
    batch_size = kwargs['batch_size']
    mask_embd = np.empty((kwargs['n_channels'], kwargs['n_mels'], 0))
    mask_mask = np.empty((kwargs['n_channels'], kwargs['n_mels'], 0))
    n_batch = int(np.ceil(spec.shape[1] / n_frames / batch_size)) \
        if batch_size > 0 else 0
    window_size = batch_size * n_frames
    for batch_i in range(max(n_batch, 1)):
        if n_batch == 0:
            s = spec
        else:
            s = spec[:, batch_i*window_size:(batch_i+1)*window_size]
        if s.shape[1] < n_frames:
            # pad so that at least one full model window is available
            s = np.hstack(
                (s, np.zeros((s.shape[0], n_frames - s.shape[1])))
            )
        x = split_window(s, n_frames).transpose((0, 2, 1))
        # predict
        embedding, mask = kwargs['model'].predict(x)
        y_embd = from_embedding(
            embedding, kwargs['n_channels'], n_jobs=kwargs['jobs'])
        y_mask = from_mask(mask)
        mini_mask_embd = merge_windows_most_common(y_embd)[:, :, :s.shape[1]]
        mini_mask_mask = merge_windows_mean(y_mask)[:, :, :s.shape[1]]
        # channel order of clustering is arbitrary: align it per chunk
        mask_embd = np.dstack((
            mask_embd,
            _best_permutation(mini_mask_embd, mini_mask_mask, s),
        ))
        mask_mask = np.dstack((mask_mask, mini_mask_mask))
    return mask_embd[:, :, :spec.shape[1]], mask_mask[:, :, :spec.shape[1]]


def part(input_path, **kwargs):
    """Separate one audio file and write its per-channel outputs.

    Writes one audio file per channel for each enabled inference (mask
    output first, then embedding output) and, with ``plot_spectrograms``,
    one figure at ``output_plot_mapper(input_path)``.
    """
    print('processing {}'.format(input_path))
    plot = kwargs['plot_spectrograms']
    if plot:
        import matplotlib.pyplot as plt

    # load audio and compute the (optionally mel-scaled) magnitude
    audio, _ = librosa.load(
        input_path, sr=kwargs['sr'], duration=kwargs['duration'])
    spec, phase = librosa.magphase(
        librosa.stft(
            audio, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_length']
        )
    )
    if 0 < kwargs['n_mels'] < kwargs['n_fft'] // 2 + 1:
        spec = np.dot(kwargs['mel_basis'], spec)

    mask_embd, mask_mask = _predict_masks(spec, **kwargs)

    # reconstruct from prediction
    fig = None
    if plot:
        fig = plt.figure(figsize=(30, 15))
        nrow = 2*kwargs['n_inference'] + 1
        ncol = kwargs['n_channels']
        ax = fig.add_subplot(nrow, ncol, 1)
        ax.title.set_text('spec. of '+input_path)
        ax.imshow(spec, origin='lower', aspect='auto')

    # order must match parse_args' ``inference_name``: mask, then embedding
    inferences = []
    if not kwargs['disable_mask_output']:
        inferences.append(mask_mask)
    if not kwargs['disable_embedding_inference']:
        # align embedding channels to mask channels over the whole file
        inferences.append(_best_permutation(mask_embd, mask_mask, spec))

    for i_inference, masks in enumerate(inferences):
        for ci, mask in enumerate(masks):
            save_audio(
                input_path, mask, spec, phase,
                i_inference=i_inference, i_channel=ci, **kwargs
            )
            if plot:
                plot_spec(
                    input_path, mask, spec, fig,
                    i_inference=i_inference, i_channel=ci, **kwargs
                )

    if plot:
        output_plot_path = kwargs['output_plot_mapper'](input_path)
        if _ensure_directory(output_plot_path):
            fig.savefig(output_plot_path)
        # release the figure; otherwise one stays open per input file
        plt.close(fig)


def plot_spec(input_path, mask, spec, fig, **kwargs):
    """Draw *mask* and the masked spectrogram into the grid of *fig*.

    The grid has ``2*n_inference + 1`` rows and ``n_channels`` columns: the
    first cell holds the input spectrogram, and each inference uses two rows
    (mask, masked spectrogram) with one column per channel.
    """
    pred_spec = mask * spec
    output_audio_path = kwargs['output_audio_mapper'](
        input_path, kwargs['i_inference'], kwargs['i_channel']
    )
    nrow = 2*kwargs['n_inference'] + 1
    ncol = kwargs['n_channels']
    ax = fig.add_subplot(
        nrow, ncol, (2*kwargs['i_inference']+1)*ncol + kwargs['i_channel'] + 1
    )
    ax.title.set_text('mask of ' + output_audio_path)
    ax.imshow(mask, origin='lower', aspect='auto')
    ax = fig.add_subplot(
        nrow, ncol, (2*kwargs['i_inference']+2)*ncol + kwargs['i_channel'] + 1
    )
    ax.title.set_text('spec. of ' + output_audio_path)
    ax.imshow(pred_spec, origin='lower', aspect='auto')


def _write_wav(path, audio, sr):
    """Write *audio* to *path* as a WAV file sampled at *sr* Hz.

    ``librosa.output.write_wav`` was removed in librosa 0.8; when it is
    missing, ``scipy.io.wavfile.write`` (which it wrapped) is used instead.
    """
    output = getattr(librosa, 'output', None)
    if output is not None:
        output.write_wav(path, audio, sr)
    else:
        from scipy.io import wavfile
        wavfile.write(path, sr, audio)


def save_audio(input_path, mask, spec, phase, **kwargs):
    """Apply *mask* to *spec*, invert the STFT with *phase* and save it."""
    pred_spec = mask * spec
    if 0 < kwargs['n_mels'] < kwargs['n_fft'] // 2 + 1:
        # back from mel bins to linear-frequency bins with the transposed
        # filterbank (an approximation, not an exact inverse)
        pred_spec = np.dot(kwargs['mel_basis'].T, pred_spec)
    output_audio_path = kwargs['output_audio_mapper'](
        input_path, kwargs['i_inference'], kwargs['i_channel']
    )
    if _ensure_directory(output_audio_path):
        out_audio = librosa.istft(
            pred_spec*phase, hop_length=kwargs['hop_length'])
        _write_wav(output_audio_path, out_audio, kwargs['sr'])


def parse_args():
    """Build the CLI, parse ``sys.argv`` and attach derived settings.

    Besides the parsed options, the returned namespace carries
    ``inference_name`` (names of the enabled inferences, mask first),
    ``n_inference``, ``channel_name`` (user names or ``ch1``, ``ch2``, ...),
    ``output_audio_mapper`` and, with ``--plot-spectrograms``,
    ``output_plot_mapper``.
    """
    parser = ArgumentParser()
    # basic arguments
    basic_group = parser.add_argument_group(title='basic arguments')
    basic_group.add_argument(
        '-m', '--model-path', type=str, metavar='PATH', required=True,
    )
    basic_group.add_argument(
        # '+': at least one path; an empty list used to crash below
        '-i', '--input-audio', type=str, nargs='+', metavar='PATH',
        required=True,
    )
    basic_group.add_argument(
        '-o', '--output-audio', type=str, metavar='FORMATTED STRING',
        default='{input:}_{infer:}_{channel:}.wav',
        help='''Formatted string as output audio path (e.g.
        "{input:}_{infer:}_{channel:}.wav" (default)).'''
    )
    basic_group.add_argument(
        '-d', '--output-directory', type=str, metavar='DIR',
        help='If specified, add it as top directory of "--output-audio"'
    )
    basic_group.add_argument(
        '--batch-size', type=int, metavar='N', default=0,
        help='Batch size on separation'
    )
    # audio arguments
    audio_group = parser.add_argument_group(title='audio arguments')
    audio_group.add_argument(
        '--sr', type=int, default=16000, metavar='N',
        help='Sampling rate (default=16000)'
    )
    audio_group.add_argument(
        '--n-fft', type=int, default=512, metavar='F',
        help='FFT window size (default=512)'
    )
    audio_group.add_argument(
        '--hop-length', type=int, default=128, metavar='N',
        help='Hop length on STFT (default=128)'
    )
    audio_group.add_argument(
        '--duration', type=float, default=0., metavar='T',
        help='Audio duration in seconds'
    )
    # advanced output arguments
    advanced_output_group = parser.add_argument_group(title='advanced output')
    advanced_output_group.add_argument(
        '--replace-top-directory', type=str, metavar='DIR',
        help='If specified, replace top directory of "--output-audio" with it'
    )
    advanced_output_group.add_argument(
        '--plot-spectrograms', action='store_true',
        help='Enable output spectrograms'
    )
    advanced_output_group.add_argument(
        '--disable-embedding-inference', action='store_true',
        help='Disable embedding inference'
    )
    advanced_output_group.add_argument(
        '--disable-mask-output', action='store_true',
        help='Disable output from mask inference'
    )
    advanced_output_group.add_argument(
        '--output-plot', type=str, metavar='FORMATTED STRING',
        default='{input:}.png',
        help='''Formatted string as output spectrogram plot prefix
        (e.g. "{input:}.png").'''
    )
    advanced_output_group.add_argument(
        '--channel-name', type=str, nargs='*', metavar='NAME',
        help='Channel names show on output audio and/or plot.'
    )
    advanced_output_group.add_argument(
        '--embedding-inference-name', type=str, metavar='NAME',
        default='embd', help='Inference name of embedding',
    )
    advanced_output_group.add_argument(
        '--mask-inference-name', type=str, metavar='NAME',
        default='mask', help='Inference name of mask',
    )
    advanced_output_group.add_argument(
        '-j', '--jobs', type=int, metavar='N', default=1,
        help='The number of jobs of k-means clustering',
    )
    args = parser.parse_args()
    if not args.duration:
        args.duration = None  # 0 means "whole file" for librosa.load

    # Names of the enabled inferences, indexed by ``i_inference`` in part().
    args.inference_name = [
        args.mask_inference_name,
        args.embedding_inference_name,
    ]
    if args.disable_embedding_inference:
        args.inference_name.pop(1)
    if args.disable_mask_output:
        args.inference_name.pop(0)
    args.n_inference = len(args.inference_name)

    if not args.channel_name:  # option absent (None) or given with no value
        class ChannelNameMapper(object):
            def __getitem__(self, key):
                return 'ch{key:}'.format(key=key+1)

            def __len__(self):
                return 10000  # sufficiently long list
        args.channel_name = ChannelNameMapper()

    try:
        args.output_audio.format(
            input=os.path.splitext(args.input_audio[0])[0],
            infer=args.mask_inference_name,
            channel=args.channel_name[0],
        )
    except (KeyError, IndexError):  # IndexError: positional fields like {0}
        parser.error(
            '"--output-audio" must not take other than '
            '"input", "infer" and "channel" as key.'
        )
    if args.output_directory and args.replace_top_directory:
        parser.error(
            '"--output-directory" and "--replace-top-directory" are '
            'mutually exclusive.'
        )

    def place_output(output_path):
        # apply --output-directory / --replace-top-directory to a path
        if args.output_directory:
            output_path = os.path.join(args.output_directory, output_path)
        if args.replace_top_directory:
            output_path = os.path.join(
                args.replace_top_directory,
                *output_path.split(os.path.sep)[1:]
            )
        return output_path

    def output_audio_mapper(input_path, i_infer, i_channel):
        return place_output(args.output_audio.format(
            input=os.path.splitext(input_path)[0],
            infer=args.inference_name[i_infer],
            channel=args.channel_name[i_channel]
        ))
    args.output_audio_mapper = output_audio_mapper

    if args.plot_spectrograms:
        if not importlib.util.find_spec('matplotlib'):
            parser.error(
                '"--plot-spectrograms": matplotlib is not installed.'
            )
        try:
            args.output_plot.format(
                input=os.path.splitext(args.input_audio[0])[0],
            )
        except (KeyError, IndexError):
            parser.error(
                '"--output-plot" must not take other than "input" as key.'
            )

        def output_plot_mapper(input_path):
            return place_output(args.output_plot.format(
                input=os.path.splitext(input_path)[0]
            ))
        args.output_plot_mapper = output_plot_mapper
    return args


if __name__ == '__main__':
    main()