Path: blob/master/examples/tacotron2/export_align.py
1558 views
import os
import shutil
from tqdm import tqdm
import argparse

from scipy.ndimage import zoom
from skimage.data import camera
import numpy as np
from scipy.spatial.distance import cdist

from pathlib import Path

# NOTE(review): shutil, skimage's `camera` and Path are not used in this script;
# the skimage import only adds a scikit-image dependency. Likely removable once
# confirmed nothing imports these names from here.


def safemkdir(dirn):
    """Create directory ``dirn`` (including missing parents) unless it already exists.

    Raises:
        FileExistsError: if ``dirn`` exists but is not a directory.
    """
    os.makedirs(dirn, exist_ok=True)


def duration_to_alignment(in_duration):
    """Convert per-character frame durations into a hard alignment matrix.

    Args:
        in_duration: 1-D sequence of integer frame counts, one per character
            (assumed non-negative).

    Returns:
        float32 array of shape ``(len(in_duration), sum(in_duration))``. Row ``i``
        is 1.0 over the consecutive frames assigned to character ``i`` and 0.0
        everywhere else.
    """
    total_len = np.sum(in_duration)
    num_chars = len(in_duration)

    attention = np.zeros(shape=(num_chars, total_len), dtype=np.float32)
    y_offset = 0
    for duration_idx, duration_val in enumerate(in_duration):
        # One slice assignment per character instead of a per-frame Python loop.
        attention[duration_idx, y_offset:y_offset + duration_val] = 1.0
        y_offset += duration_val

    return attention


def rescale_alignment(in_alignment, in_targcharlen):
    """Resample an alignment along the character axis to ``in_targcharlen`` rows.

    The matrix is zoomed along axis 0 only, then binarised (values < 0.5 become
    0.0, everything else 1.0).

    Args:
        in_alignment: 2-D alignment, characters x frames.
        in_targcharlen: required number of rows (characters).

    Returns:
        ``(zoomed, pivot_points)``. ``zoomed`` has shape
        ``(in_targcharlen, in_alignment.shape[1])`` and the dtype of the zoom
        output. ``pivot_points`` lists the ``(row, col)`` coordinates of its
        1.0 cells in row-major order.
    """
    current_x = in_alignment.shape[0]
    x_ratio = in_targcharlen / current_x

    zoomed = zoom(in_alignment, (x_ratio, 1.0), mode="nearest")
    # Spline interpolation produces fractional values; snap back to a hard 0/1 map.
    zoomed = np.where(zoomed < 0.5, 0.0, 1.0).astype(zoomed.dtype)

    if zoomed.shape[0] != in_targcharlen:
        print("Zooming didn't rshape well, explicitly reshaping")
        # Truncate surplus rows or zero-pad missing ones. This is the same result
        # ndarray.resize gave here (the column count is unchanged), but without
        # its reference-count check, which can raise ValueError.
        fixed = np.zeros((in_targcharlen, in_alignment.shape[1]), dtype=zoomed.dtype)
        kept_rows = min(in_targcharlen, zoomed.shape[0])
        fixed[:kept_rows] = zoomed[:kept_rows]
        zoomed = fixed

    # Collect pivots only after the final shape is known, so none of them points
    # at a row that was truncated away.
    pivot_points = get_pivot_points(zoomed)
    return zoomed, pivot_points


def gather_dist(in_mtr, in_points):
    """Return Euclidean distances from every cell of ``in_mtr`` to each point.

    Args:
        in_mtr: 2-D array; only its shape is used.
        in_points: sequence of ``(row, col)`` points (must be non-empty for cdist).

    Returns:
        Array of shape ``(rows * cols, len(in_points))``. Row ``i`` corresponds to
        cell ``(i // cols, i % cols)`` (row-major order).
    """
    rows, cols = in_mtr.shape[0], in_mtr.shape[1]
    # (rows * cols, 2) array of (row, col) coordinates in row-major order.
    full_coords = np.indices((rows, cols)).reshape(2, -1).T
    return cdist(full_coords, in_points, "euclidean")


def create_guided(in_align, in_pvt, looseness):
    """Build a guided-attention weight map around a set of pivot points.

    Each cell gets ``(d / real_loose) ** 2`` clipped to ``[0, 1]``. Here ``d`` is
    the Euclidean distance to the nearest pivot and
    ``real_loose = looseness / 100 * (rows + cols)``.

    Args:
        in_align: 2-D alignment; only its shape is used.
        in_pvt: sequence of ``(row, col)`` pivot coordinates.
        looseness: width factor; higher values give a wider low-weight band.

    Returns:
        float32 array with the shape of ``in_align``. If ``in_pvt`` is empty,
        every cell is 1.0.
    """
    rows, cols = in_align.shape[0], in_align.shape[1]
    if len(in_pvt) == 0:
        # cdist() rejects an empty point set. With no pivot, every cell is
        # infinitely far from one, which clips to 1.0.
        return np.ones(in_align.shape, dtype=np.float32)

    # Computing all distances at once is dramatically faster than a per-cell loop.
    dist_arr = gather_dist(in_align, in_pvt)
    # Scale looseness by attention size (addition works better than mul). It is
    # also divided by 100 because a user-facing value like 3.35 is nicer.
    real_loose = (looseness / 100) * (rows + cols)

    nearest = dist_arr.min(axis=1).reshape(rows, cols)
    new_att = np.power(nearest / real_loose, 2).astype(np.float32)
    return np.clip(new_att, 0.0, 1.0)


def get_pivot_points(in_att):
    """Return ``(row, col)`` coordinates of cells of 2-D ``in_att`` that are > 0.8.

    Coordinates are plain ints, listed in row-major order.
    """
    return [tuple(coord) for coord in np.argwhere(in_att > 0.8).tolist()]


def main():
    """Turn ``<dump-dir>/{train,valid}/durations/*.npy`` into guided-attention maps.

    Each map is saved under ``<dump-dir>/<set>/alignments`` with ``-ids`` in the
    matching ids filename replaced by ``-alignment``.
    """
    parser = argparse.ArgumentParser(
        description="Postprocess durations to become alignments"
    )
    parser.add_argument(
        "--dump-dir",
        default="dump",
        type=str,
        help="Path of dump directory",
    )
    parser.add_argument(
        "--looseness",
        default=3.5,
        type=float,
        help="Looseness of the generated guided attention map. Lower values = tighter",
    )
    args = parser.parse_args()
    # A zero/negative looseness makes the distance scale zero, producing inf/NaN maps.
    if args.looseness <= 0:
        parser.error("--looseness must be a positive number")

    dump_dir = args.dump_dir
    dump_sets = ["train", "valid"]

    for d_set in dump_sets:
        full_fol = os.path.join(dump_dir, d_set)
        align_path = os.path.join(full_fol, "alignments")
        ids_path = os.path.join(full_fol, "ids")
        durations_path = os.path.join(full_fol, "durations")

        safemkdir(align_path)

        # sorted() makes the processing order deterministic across filesystems.
        for duration_fn in tqdm(sorted(os.listdir(durations_path))):
            if not duration_fn.endswith(".npy"):
                continue

            id_fn = duration_fn.replace("-durations", "-ids")
            id_path = os.path.join(ids_path, id_fn)
            duration_path = os.path.join(durations_path, duration_fn)

            duration_arr = np.load(duration_path)
            id_arr = np.load(id_path)
            id_true_size = len(id_arr)

            align = duration_to_alignment(duration_arr)

            if align.shape[0] != id_true_size:
                # Durations and ids disagree on the character count: resample
                # the alignment so it has one row per id.
                align, points = rescale_alignment(align, id_true_size)
            else:
                points = get_pivot_points(align)

            if len(points) == 0:
                print("WARNING points are empty for", id_fn)

            align = create_guided(align, points, args.looseness)

            align_fn = id_fn.replace("-ids", "-alignment")
            align_full_fn = os.path.join(align_path, align_fn)
            np.save(align_full_fn, align.astype("float32"))


if __name__ == "__main__":
    main()