You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
#!/usr/bin/env python
|
|
import numpy as np
|
|
import soundfile as sf
|
|
from sklearn.cluster import KMeans
|
|
from scipy.signal import butter, lfilter
|
|
|
|
def highpass_filter(audio, sr, cutoff=20.0):
    """Apply a first-order Butterworth high-pass filter to *audio*.

    Args:
        audio: Samples to filter.
        sr: Sample rate of *audio*, in the same unit as *cutoff*.
        cutoff: Corner frequency of the filter.

    Returns:
        The filtered samples, as returned by ``scipy.signal.lfilter``.
    """
    nyquist = sr / 2
    # butter() takes the corner frequency as a fraction of the Nyquist rate.
    numerator, denominator = butter(1, cutoff / nyquist, btype='highpass')
    return lfilter(numerator, denominator, audio)
|
|
|
|
def find_upward_zero_crossings(signal):
    """Locate samples where *signal* rises from negative to non-negative.

    Args:
        signal: Array of samples.

    Returns:
        Array of indices ``i`` for which ``signal[i - 1] < 0`` and
        ``signal[i] >= 0``.
    """
    previous, current = signal[:-1], signal[1:]
    rising = (previous < 0) & (current >= 0)
    # The mask is aligned with `previous`; shift by one so that each index
    # points at the first non-negative sample of the crossing.
    return np.nonzero(rising)[0] + 1
|
|
|
|
def extract_wavesets(signal):
    """Cut *signal* into the segments between consecutive upward zero crossings.

    Samples before the first crossing and from the last crossing onward are
    not part of any returned segment.

    Args:
        signal: Array of samples.

    Returns:
        A ``(wavesets, crossings)`` tuple: the list of signal slices and the
        array of crossing indices that delimit them.
    """
    crossings = find_upward_zero_crossings(signal)
    segments = [
        signal[start:stop]
        for start, stop in zip(crossings[:-1], crossings[1:])
    ]
    return segments, crossings
|
|
|
|
def compute_features(wavesets):
    """Measure the length and RMS amplitude of every waveset.

    Args:
        wavesets: Sequence of sample arrays.

    Returns:
        A ``(lengths, rms)`` tuple of arrays with one entry per waveset.
    """
    sample_counts = []
    amplitudes = []
    for segment in wavesets:
        sample_counts.append(len(segment))
        # Root of the mean squared sample value.
        amplitudes.append(np.sqrt(np.mean(segment**2)))
    return np.array(sample_counts), np.array(amplitudes)
|
|
|
|
def _standardize(values):
    """Return the z-scores of *values*; constant input maps to all zeros."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        # Nothing to normalize; avoids mean/std warnings on empty input.
        return values
    spread = np.std(values)
    if spread == 0:
        # Every value is identical. Dividing by zero would fill the column
        # with NaN; using 1 leaves the (all-zero) centered values unchanged.
        spread = 1.0
    return (values - np.mean(values)) / spread


def normalize_and_weight(lengths, rms, w):
    """Build the feature matrix from per-waveset lengths and RMS values.

    Both features are standardized to zero mean and unit variance, then the
    length feature is multiplied by *w*. A feature whose values are all equal
    becomes a column of zeros instead of NaN.

    Args:
        lengths: Per-waveset lengths.
        rms: Per-waveset RMS values, same size as *lengths*.
        w: Factor applied to the standardized length feature.

    Returns:
        Array of shape ``(n, 2)`` with columns ``[weighted length, rms]``.
    """
    weighted_lengths = _standardize(lengths) * w
    return np.stack([weighted_lengths, _standardize(rms)], axis=1)
|
|
|
|
def replace_with_representatives(wavesets, labels, centroids, features):
    """Swap every waveset for the representative of its cluster.

    The representative of a cluster is the member waveset whose feature
    vector lies closest (Euclidean distance) to the cluster centroid.

    Args:
        wavesets: Sequence of waveset sample arrays.
        labels: Cluster index of each waveset, aligned with *wavesets*.
        centroids: Array with one centroid row per cluster.
        features: Array with one feature row per waveset.

    Returns:
        A list aligned with *labels* holding, for each position, the
        representative waveset of that position's cluster.
    """
    labels = np.asarray(labels)
    representatives = [None] * centroids.shape[0]
    for k, centroid in enumerate(centroids):
        members = np.flatnonzero(labels == k)
        if members.size == 0:
            # No waveset carries this label, so the slot is never looked up;
            # skipping avoids calling argmin on an empty array.
            continue
        distances = np.linalg.norm(features[members] - centroid, axis=1)
        representatives[k] = wavesets[members[np.argmin(distances)]]
    return [representatives[label] for label in labels]
|
|
|
|
def reconstruct_signal(replaced_sets, zero_crossings, length):
    """Lay the wavesets end to end into a zero-filled buffer.

    Writing starts at the first zero crossing. Wavesets are placed back to
    back until one would run past the end of the buffer; that waveset and all
    following ones are dropped, leaving the remainder of the buffer at zero.

    Args:
        replaced_sets: Sequence of waveset sample arrays, in playback order.
        zero_crossings: Zero-crossing indices; only the first one is used.
        length: Number of samples in the output buffer.

    Returns:
        Array of *length* samples. All zeros when *zero_crossings* is empty.
    """
    output = np.zeros(length)
    if len(zero_crossings) == 0:
        # No crossing means there is no start position and no wavesets.
        return output

    cursor = zero_crossings[0]
    for ws in replaced_sets:
        end = cursor + len(ws)
        if end > len(output):
            break
        output[cursor:end] = ws
        cursor = end
    return output
|
|
|
|
def waveset_clustering_effect(filepath, output_path, w=5, clusters_per_sec=20):
    """Apply the waveset-clustering effect to an audio file.

    The file is read, averaged down to one channel if it has several,
    high-pass filtered and split into wavesets. The wavesets are clustered
    with k-means on their (weighted length, RMS) features, each waveset is
    replaced by its cluster's representative, and the result is written out
    at the input sample rate.

    Args:
        filepath: Path of the audio file to read.
        output_path: Path of the audio file to write.
        w: Weight applied to the normalized length feature.
        clusters_per_sec: Number of clusters per second of input audio.

    Raises:
        ValueError: If the signal yields fewer than two wavesets, in which
            case there is nothing to cluster.
    """
    signal, sr = sf.read(filepath)
    if signal.ndim > 1:
        signal = signal.mean(axis=1)  # Mono

    signal = highpass_filter(signal, sr)
    wavesets, zero_crossings = extract_wavesets(signal)
    if len(wavesets) < 2:
        # The cluster count below is at least 2; with fewer wavesets k-means
        # would fail with a far less helpful message.
        raise ValueError(
            f"{filepath!r} yields {len(wavesets)} waveset(s); "
            "at least 2 are required for clustering"
        )

    lengths, rms = compute_features(wavesets)
    features = normalize_and_weight(lengths, rms, w)

    total_time = len(signal) / sr
    k = int(clusters_per_sec * total_time)
    # At least 2 clusters, and never more clusters than wavesets.
    k = max(2, min(k, len(wavesets)))

    kmeans = KMeans(n_clusters=k, random_state=0).fit(features)
    replaced_sets = replace_with_representatives(
        wavesets, kmeans.labels_, kmeans.cluster_centers_, features
    )

    reconstructed = reconstruct_signal(replaced_sets, zero_crossings, len(signal))
    sf.write(output_path, reconstructed, sr)
|
|
|
|
# Example usage:
if __name__ == "__main__":
    # Guarded so that importing this module does not process any files.
    waveset_clustering_effect("input.wav", "output.wav", w=5, clusters_per_sec=15)
|
|
|