Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/hub/tutorials/spice.ipynb
25118 views
Kernel: Python 3

Licensed under the Apache License, Version 2.0 (the "License");

#@title Copyright 2020 The TensorFlow Hub Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ==============================================================================

SPICE によるピッチ検出

この Colab では、TensorFlow Hub からダウンロードした SPICE モデルの使用方法を紹介します。

!sudo apt-get install -q -y timidity libsndfile1
# All the imports to deal with sound data !pip install pydub librosa music21
import tensorflow as tf import tensorflow_hub as hub import numpy as np import matplotlib.pyplot as plt import librosa from librosa import display as librosadisplay import logging import math import statistics import sys from IPython.display import Audio, Javascript from scipy.io import wavfile from base64 import b64decode import music21 from pydub import AudioSegment logger = logging.getLogger() logger.setLevel(logging.ERROR) print("tensorflow: %s" % tf.__version__) #print("librosa: %s" % librosa.__version__)

音声入力ファイル

これが最も困難な部分です。あなたの歌声を録音しましょう!😃

音声ファイルの取得には、次の 4 つの方法があります。

  1. Colab で直接音声を録音する

  2. ご利用の PC からアップロードする

  3. Google Drive に保存されたファイルを使用する

  4. ウェブからファイルをダウンロードする

以下の 4 つの方法から 1 つを選択してください。

#@title [Run this] Definition of the JS code to record audio straight from the browser RECORD = """ const sleep = time => new Promise(resolve => setTimeout(resolve, time)) const b2text = blob => new Promise(resolve => { const reader = new FileReader() reader.onloadend = e => resolve(e.srcElement.result) reader.readAsDataURL(blob) }) var record = time => new Promise(async resolve => { stream = await navigator.mediaDevices.getUserMedia({ audio: true }) recorder = new MediaRecorder(stream) chunks = [] recorder.ondataavailable = e => chunks.push(e.data) recorder.start() await sleep(time) recorder.onstop = async ()=>{ blob = new Blob(chunks) text = await b2text(blob) resolve(text) } recorder.stop() }) """ def record(sec=5): try: from google.colab import output except ImportError: print('No possible to import output from google.colab') return '' else: print('Recording') display(Javascript(RECORD)) s = output.eval_js('record(%d)' % (sec*1000)) fname = 'recorded_audio.wav' print('Saving to', fname) b = b64decode(s.split(',')[1]) with open(fname, 'wb') as f: f.write(b) return fname
#@title Select how to input your audio { run: "auto" } INPUT_SOURCE = 'https://storage.googleapis.com/download.tensorflow.org/data/c-scale-metronome.wav' #@param ["https://storage.googleapis.com/download.tensorflow.org/data/c-scale-metronome.wav", "RECORD", "UPLOAD", "./drive/My Drive/YOUR_MUSIC_FILE.wav"] {allow-input: true} print('You selected', INPUT_SOURCE) if INPUT_SOURCE == 'RECORD': uploaded_file_name = record(5) elif INPUT_SOURCE == 'UPLOAD': try: from google.colab import files except ImportError: print("ImportError: files from google.colab seems to not be available") else: uploaded = files.upload() for fn in uploaded.keys(): print('User uploaded file "{name}" with length {length} bytes'.format( name=fn, length=len(uploaded[fn]))) uploaded_file_name = next(iter(uploaded)) print('Uploaded file: ' + uploaded_file_name) elif INPUT_SOURCE.startswith('./drive/'): try: from google.colab import drive except ImportError: print("ImportError: files from google.colab seems to not be available") else: drive.mount('/content/drive') # don't forget to change the name of the file you # will you here! gdrive_audio_file = 'YOUR_MUSIC_FILE.wav' uploaded_file_name = INPUT_SOURCE elif INPUT_SOURCE.startswith('http'): !wget --no-check-certificate 'https://storage.googleapis.com/download.tensorflow.org/data/c-scale-metronome.wav' -O c-scale.wav uploaded_file_name = 'c-scale.wav' else: print('Unrecognized input format!') print('Please select "RECORD", "UPLOAD", or specify a file hosted on Google Drive or a file from the web to download file to download')

音声データを準備する

音声データを入手したので、期待される形式に変換して聴いてみましょう!

SPICE モデルでは、入力として、サンプリングレート 16 kHz の音声ファイルが必要です。また、チャンネル数は 1 つ(Mono)である必要があります。

この部分の作業を支援するために、wav ファイルをモデルが期待する形式に変換する関数(convert_audio_for_model)を用意しました。

# Function that converts the user-created audio to the format that the model # expects: bitrate 16kHz and only one channel (mono). EXPECTED_SAMPLE_RATE = 16000 def convert_audio_for_model(user_file, output_file='converted_audio_file.wav'): audio = AudioSegment.from_file(user_file) audio = audio.set_frame_rate(EXPECTED_SAMPLE_RATE).set_channels(1) audio.export(output_file, format="wav") return output_file
# Converting to the expected format for the model # in all the input 4 input method before, the uploaded file name is at # the variable uploaded_file_name converted_audio_file = convert_audio_for_model(uploaded_file_name)
# Loading audio samples from the wav file: sample_rate, audio_samples = wavfile.read(converted_audio_file, 'rb') # Show some basic information about the audio. duration = len(audio_samples)/sample_rate print(f'Sample rate: {sample_rate} Hz') print(f'Total duration: {duration:.2f}s') print(f'Size of the input: {len(audio_samples)}') # Let's listen to the wav file. Audio(audio_samples, rate=sample_rate)

まずはじめに、歌声の波形を見てみましょう。

# We can visualize the audio as a waveform. _ = plt.plot(audio_samples)

より詳しいビジュアライゼーションとして、経時的に周波数を示すスペクトログラム があります。

ここでは、対数による周波数スケールを使用して、歌声をより鮮明に視覚化させます。

MAX_ABS_INT16 = 32768.0 def plot_stft(x, sample_rate, show_black_and_white=False): x_stft = np.abs(librosa.stft(x, n_fft=2048)) fig, ax = plt.subplots() fig.set_size_inches(20, 10) x_stft_db = librosa.amplitude_to_db(x_stft, ref=np.max) if(show_black_and_white): librosadisplay.specshow(data=x_stft_db, y_axis='log', sr=sample_rate, cmap='gray_r') else: librosadisplay.specshow(data=x_stft_db, y_axis='log', sr=sample_rate) plt.colorbar(format='%+2.0f dB') plot_stft(audio_samples / MAX_ABS_INT16 , sample_rate=EXPECTED_SAMPLE_RATE) plt.show()

ここで、最後の変換を行う必要があります。音声サンプルは int16 形式です。これらを -1 と 1 の間の浮動小数点数に正規化する必要があります。

audio_samples = audio_samples / float(MAX_ABS_INT16)

モデルを実行する

ようやく簡単な作業です。TensorFlow Hub でモデルを読み込み、音声をフィードしましょう。SPICE から ピッチと不確実性の 2 つの出力が得られます。

TensorFlow Hub は、機械学習モデルの再利用可能な部分の公開、発見、および消費のためのライブラリです。ユーザーの抱える課題の解決する機械学習の使用を簡単にすることができます。

モデルを読み込むには、Hub モジュールと、モデルにポイントする URL のみが必要です。

# Loading the SPICE model is easy: model = hub.load("https://tfhub.dev/google/spice/2")

注意: ここでの豆知識は、Hub のすべてのモデル URL は、ダウンロードだけでなく、ドキュメントの読み取りにも利用できるということです。そのため、ブラウザでそのリンクを開くと、モデルの使用方法を読み、どのようにしてトレーニングされたのかという詳細も知ることができます。

モデルが読み込まれ、データの準備が完了したので、結果を取得するための 3 行を追加しましょう。

# We now feed the audio to the SPICE tf.hub model to obtain pitch and uncertainty outputs as tensors. model_output = model.signatures["serving_default"](tf.constant(audio_samples, tf.float32)) pitch_outputs = model_output["pitch"] uncertainty_outputs = model_output["uncertainty"] # 'Uncertainty' basically means the inverse of confidence. confidence_outputs = 1.0 - uncertainty_outputs fig, ax = plt.subplots() fig.set_size_inches(20, 10) plt.plot(pitch_outputs, label='pitch') plt.plot(confidence_outputs, label='confidence') plt.legend(loc="lower right") plt.show()

信頼度の低い(confidence < 0.9)すべてのピッチ推定値を取り除いて、残りのピッチをグラフ化して、結果を理解しやすくしましょう。

confidence_outputs = list(confidence_outputs) pitch_outputs = [ float(x) for x in pitch_outputs] indices = range(len (pitch_outputs)) confident_pitch_outputs = [ (i,p) for i, p, c in zip(indices, pitch_outputs, confidence_outputs) if c >= 0.9 ] confident_pitch_outputs_x, confident_pitch_outputs_y = zip(*confident_pitch_outputs) fig, ax = plt.subplots() fig.set_size_inches(20, 10) ax.set_ylim([0, 1]) plt.scatter(confident_pitch_outputs_x, confident_pitch_outputs_y, ) plt.scatter(confident_pitch_outputs_x, confident_pitch_outputs_y, c="r") plt.show()

SPICE が返すピッチ値の範囲は 0 から 1 です。この値を Hz 単位の絶対ピッチ値に変換しましょう。

def output2hz(pitch_output): # Constants taken from https://tfhub.dev/google/spice/2 PT_OFFSET = 25.58 PT_SLOPE = 63.07 FMIN = 10.0; BINS_PER_OCTAVE = 12.0; cqt_bin = pitch_output * PT_SLOPE + PT_OFFSET; return FMIN * 2.0 ** (1.0 * cqt_bin / BINS_PER_OCTAVE) confident_pitch_values_hz = [ output2hz(p) for p in confident_pitch_outputs_y ]

では、どれほど良い予測が得られるか確認しましょう。予測されるピッチを元のスペクトログラムにオーバーレイ表示します。ピッチ予測をより見やすくするために、スペクトログラムをモノクロに変更してます。

plot_stft(audio_samples / MAX_ABS_INT16 , sample_rate=EXPECTED_SAMPLE_RATE, show_black_and_white=True) # Note: conveniently, since the plot is in log scale, the pitch outputs # also get converted to the log scale automatically by matplotlib. plt.scatter(confident_pitch_outputs_x, confident_pitch_values_hz, c="r") plt.show()

音符に変換する

ピッチ値を得たので、次はこの値を音符に変換することにしましょう!この作業はそれ自体が困難となる部分です。次の 2 つの項目を考慮する必要があります。

  1. 休符(歌声がないところ)

  2. 各音符のサイズ(オフセット)

1: 出力にゼロを追加して、歌声がない部分を示す

pitch_outputs_and_rests = [ output2hz(p) if c >= 0.9 else 0 for i, p, c in zip(indices, pitch_outputs, confidence_outputs) ]

2: 音符のオフセットを追加する

自由に歌う場合、メロディーには音符が表現する絶対ピッチ値に対するオフセットがある場合があります。したがって、予測を音符に変換するには、この潜在的なオフセットを修正する必要があります。次のコードはこれを計算しています。

A4 = 440 C0 = A4 * pow(2, -4.75) note_names = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] def hz2offset(freq): # This measures the quantization error for a single note. if freq == 0: # Rests always have zero error. return None # Quantized note. h = round(12 * math.log2(freq / C0)) return 12 * math.log2(freq / C0) - h # The ideal offset is the mean quantization error for all the notes # (excluding rests): offsets = [hz2offset(p) for p in pitch_outputs_and_rests if p != 0] print("offsets: ", offsets) ideal_offset = statistics.mean(offsets) print("ideal offset: ", ideal_offset)

ヒューリスティックを使用して歌われた可能性の最も高い音符のシーケンスの予測を試みることができるようになりました。上記で計算された理想的なオフセットは 1 つの材料ではありますが、速度(8 個など、どれくらいの予測をたてるのか)と量子化を始める時間オフセットを知る必要もあります。単純にしておくために、異なる速度と時間オフセットを試して、量子化誤差を測定し、最終的に、この誤差を最小限に抑える値を使用します。

def quantize_predictions(group, ideal_offset): # Group values are either 0, or a pitch in Hz. non_zero_values = [v for v in group if v != 0] zero_values_count = len(group) - len(non_zero_values) # Create a rest if 80% is silent, otherwise create a note. if zero_values_count > 0.8 * len(group): # Interpret as a rest. Count each dropped note as an error, weighted a bit # worse than a badly sung note (which would 'cost' 0.5). return 0.51 * len(non_zero_values), "Rest" else: # Interpret as note, estimating as mean of non-rest predictions. h = round( statistics.mean([ 12 * math.log2(freq / C0) - ideal_offset for freq in non_zero_values ])) octave = h // 12 n = h % 12 note = note_names[n] + str(octave) # Quantization error is the total difference from the quantized note. error = sum([ abs(12 * math.log2(freq / C0) - ideal_offset - h) for freq in non_zero_values ]) return error, note def get_quantization_and_error(pitch_outputs_and_rests, predictions_per_eighth, prediction_start_offset, ideal_offset): # Apply the start offset - we can just add the offset as rests. pitch_outputs_and_rests = [0] * prediction_start_offset + \ pitch_outputs_and_rests # Collect the predictions for each note (or rest). groups = [ pitch_outputs_and_rests[i:i + predictions_per_eighth] for i in range(0, len(pitch_outputs_and_rests), predictions_per_eighth) ] quantization_error = 0 notes_and_rests = [] for group in groups: error, note_or_rest = quantize_predictions(group, ideal_offset) quantization_error += error notes_and_rests.append(note_or_rest) return quantization_error, notes_and_rests best_error = float("inf") best_notes_and_rests = None best_predictions_per_note = None for predictions_per_note in range(20, 65, 1): for prediction_start_offset in range(predictions_per_note): error, notes_and_rests = get_quantization_and_error( pitch_outputs_and_rests, predictions_per_note, prediction_start_offset, ideal_offset) if error < best_error: best_error = error best_notes_and_rests = notes_and_rests best_predictions_per_note = predictions_per_note # At this point, best_notes_and_rests contains the best quantization. # Since we don't need to have rests at the beginning, let's remove these: while best_notes_and_rests[0] == 'Rest': best_notes_and_rests = best_notes_and_rests[1:] # Also remove silence at the end. while best_notes_and_rests[-1] == 'Rest': best_notes_and_rests = best_notes_and_rests[:-1]

では、量子化された音符を楽譜として書き出しましょう!

これを行うには、music21Open Sheet Music Display の 2 つのライブラリを使用します。

注意: 単純に行えるように、ここではすべての音符の長さが同じ(半音)であると仮定しています。

# Creating the sheet music score. sc = music21.stream.Score() # Adjust the speed to match the actual singing. bpm = 60 * 60 / best_predictions_per_note print ('bpm: ', bpm) a = music21.tempo.MetronomeMark(number=bpm) sc.insert(0,a) for snote in best_notes_and_rests: d = 'half' if snote == 'Rest': sc.append(music21.note.Rest(type=d)) else: sc.append(music21.note.Note(snote, type=d))
#@title [Run this] Helper function to use Open Sheet Music Display (JS code) to show a music score from IPython.core.display import display, HTML, Javascript import json, random def showScore(score): xml = open(score.write('musicxml')).read() showMusicXML(xml) def showMusicXML(xml): DIV_ID = "OSMD_div" display(HTML('<div id="'+DIV_ID+'">loading OpenSheetMusicDisplay</div>')) script = """ var div_id = %%DIV_ID%%; function loadOSMD() { return new Promise(function(resolve, reject){ if (window.opensheetmusicdisplay) { return resolve(window.opensheetmusicdisplay) } // OSMD script has a 'define' call which conflicts with requirejs var _define = window.define // save the define object window.define = undefined // now the loaded script will ignore requirejs var s = document.createElement( 'script' ); s.setAttribute( 'src', "https://cdn.jsdelivr.net/npm/[email protected]/build/opensheetmusicdisplay.min.js" ); //s.setAttribute( 'src', "/custom/opensheetmusicdisplay.js" ); s.onload=function(){ window.define = _define resolve(opensheetmusicdisplay); }; document.body.appendChild( s ); // browser will try to load the new script tag }) } loadOSMD().then((OSMD)=>{ window.openSheetMusicDisplay = new OSMD.OpenSheetMusicDisplay(div_id, { drawingParameters: "compacttight" }); openSheetMusicDisplay .load(%%data%%) .then( function() { openSheetMusicDisplay.render(); } ); }) """.replace('%%DIV_ID%%',DIV_ID).replace('%%data%%',json.dumps(xml)) display(Javascript(script)) return
# rendering the music score showScore(sc) print(best_notes_and_rests)

音符を MIDI ファイルに変換して、聴いてみましょう。

このファイルを作成するには、前に作成したストリームを使用できます。

# Saving the recognized musical notes as a MIDI file converted_audio_file_as_midi = converted_audio_file[:-4] + '.mid' fp = sc.write('midi', fp=converted_audio_file_as_midi)
wav_from_created_midi = converted_audio_file_as_midi.replace(' ', '_') + "_midioutput.wav" print(wav_from_created_midi)

Colab で聴くには、wav に変換し直す必要があります。簡単な方法は、Timidty を使用することです。

!timidity $converted_audio_file_as_midi -Ow -o $wav_from_created_midi

最後に、モデルが推論し、予測されたピッチから MIDI 経由で作成され、音符から作成された音声を聴いてみましょう!

Audio(wav_from_created_midi)