You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
2.6 KiB
Python

#!/usr/bin/env python
import numpy as np
import soundfile as sf
from sklearn.cluster import KMeans
from scipy.signal import butter, lfilter
def highpass_filter(audio, sr, cutoff=20.0):
    """Return ``audio`` passed through a first-order Butterworth high-pass filter.

    Args:
        audio: samples to filter (filtered along the last axis by ``lfilter``).
        sr: sample rate in Hz.
        cutoff: corner frequency in Hz; must lie below the Nyquist rate ``sr / 2``.

    Returns:
        The filtered samples, same shape as ``audio``.
    """
    nyquist = sr / 2
    numerator, denominator = butter(1, cutoff / nyquist, btype='highpass')
    return lfilter(numerator, denominator, audio)
def find_upward_zero_crossings(signal):
    """Return the indices at which ``signal`` crosses zero going upward.

    Index ``i`` is reported when ``signal[i - 1] < 0`` and ``signal[i] >= 0``,
    i.e. the returned index is the first non-negative sample after the crossing.
    """
    was_negative = signal[:-1] < 0
    now_non_negative = signal[1:] >= 0
    # The masks are offset by one sample, so shift the hits onto the later sample.
    return np.nonzero(was_negative & now_non_negative)[0] + 1
def extract_wavesets(signal):
    """Split ``signal`` into wavesets bounded by consecutive upward zero crossings.

    Samples before the first crossing and from the last crossing onward are not
    part of any waveset.

    Returns:
        A tuple ``(wavesets, crossings)`` where ``wavesets`` is a list of slices
        ``signal[crossings[j]:crossings[j + 1]]`` and ``crossings`` is the array
        returned by ``find_upward_zero_crossings``.
    """
    crossings = find_upward_zero_crossings(signal)
    segments = [
        signal[start:stop]
        for start, stop in zip(crossings[:-1], crossings[1:])
    ]
    return segments, crossings
def compute_features(wavesets):
    """Compute the clustering features of each waveset.

    Returns:
        A tuple ``(lengths, rms)`` of arrays with one entry per waveset: its
        length in samples and its root-mean-square amplitude.
    """
    sizes = []
    levels = []
    for segment in wavesets:
        sizes.append(len(segment))
        levels.append(np.sqrt(np.mean(segment ** 2)))
    return np.array(sizes), np.array(levels)
def _standardize(values):
    """Return the z-scores of ``values``; a constant (or empty) input maps to zeros."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    # A constant feature has zero spread: dividing by it gives NaNs that KMeans
    # rejects, so treat it as carrying no information and centre it at zero.
    if values.max() == values.min():
        return np.zeros_like(values)
    return (values - np.mean(values)) / np.std(values)


def normalize_and_weight(lengths, rms, w):
    """Standardize both features and weight the length feature.

    Args:
        lengths: per-waveset lengths in samples.
        rms: per-waveset RMS amplitudes.
        w: multiplier on the standardized length column; larger values make
            length count more than RMS in Euclidean distances.

    Returns:
        An ``(n, 2)`` float array whose columns are the weighted standardized
        lengths and the standardized RMS values.
    """
    return np.stack([_standardize(lengths) * w, _standardize(rms)], axis=1)
def replace_with_representatives(wavesets, labels, centroids, features):
    """Replace every waveset by the most central member of its cluster.

    For each cluster, the representative is the member whose feature row is
    closest (Euclidean distance) to the cluster centroid.

    Args:
        wavesets: sequence of waveset arrays, indexed like ``labels``.
        labels: cluster index of each waveset.
        centroids: ``(k, d)`` array of cluster centres.
        features: ``(n, d)`` feature rows, one per waveset.

    Returns:
        A list with one entry per waveset: its cluster's representative waveset.
    """
    labels = np.asarray(labels)
    representatives = [None] * centroids.shape[0]
    for k in range(centroids.shape[0]):
        members = np.flatnonzero(labels == k)
        if members.size == 0:
            # KMeans may leave a cluster empty (e.g. with duplicate feature
            # rows); no label points at it, so there is nothing to choose.
            continue
        dists = np.linalg.norm(features[members] - centroids[k], axis=1)
        representatives[k] = wavesets[members[np.argmin(dists)]]
    return [representatives[label] for label in labels]
def reconstruct_signal(replaced_sets, zero_crossings, length):
    """Concatenate wavesets into a zero-initialised buffer of ``length`` samples.

    Writing starts at the first zero crossing and stops at the first waveset
    that would overrun the buffer. Samples that are not written stay 0, which
    includes everything before the first crossing.

    Args:
        replaced_sets: wavesets to lay down back to back.
        zero_crossings: crossing indices; only the first is used, as the start.
        length: number of samples in the output.

    Returns:
        A float array of ``length`` samples.
    """
    output = np.zeros(length)
    if len(zero_crossings) == 0:
        # No upward crossing means no wavesets: the output is silence.
        return output
    limit = len(output)
    cursor = int(zero_crossings[0])
    for waveset in replaced_sets:
        end = cursor + len(waveset)
        if end > limit:
            break
        output[cursor:end] = waveset
        cursor = end
    return output
def waveset_clustering_effect(filepath, output_path, w=5, clusters_per_sec=20):
    """Resynthesize an audio file from clustered wavesets.

    The file is read with soundfile and downmixed to mono by averaging the
    channels. It is then high-pass filtered and cut into wavesets at upward
    zero crossings. The wavesets are clustered with KMeans on (weighted length,
    RMS), each waveset is replaced by its cluster's most central member, and
    the result is written to ``output_path`` at the input sample rate.

    Args:
        filepath: path of the input audio file.
        output_path: path of the output audio file.
        w: weight on the standardized length feature relative to RMS.
        clusters_per_sec: clusters per second of input audio; the count is
            clamped to between 2 and the number of wavesets.

    Raises:
        ValueError: if the signal contains fewer than two wavesets, so there
            is nothing to cluster.
    """
    signal, sr = sf.read(filepath)
    if signal.ndim > 1:
        signal = signal.mean(axis=1)  # Mono
    signal = highpass_filter(signal, sr)
    wavesets, zero_crossings = extract_wavesets(signal)
    if len(wavesets) < 2:
        # KMeans needs at least as many samples as clusters (k >= 2 below).
        raise ValueError(
            f"{filepath!r} yields {len(wavesets)} waveset(s); "
            "at least 2 are needed for clustering"
        )
    lengths, rms = compute_features(wavesets)
    features = normalize_and_weight(lengths, rms, w)
    total_time = len(signal) / sr
    k = int(clusters_per_sec * total_time)
    k = max(2, min(k, len(wavesets)))  # Avoid trivial or impossible cases
    kmeans = KMeans(n_clusters=k, random_state=0).fit(features)
    replaced_sets = replace_with_representatives(
        wavesets, kmeans.labels_, kmeans.cluster_centers_, features
    )
    reconstructed = reconstruct_signal(replaced_sets, zero_crossings, len(signal))
    sf.write(output_path, reconstructed, sr)
# Example usage: run this file directly to process "input.wav" into "output.wav".
# The guard keeps importing this module free of file I/O side effects.
if __name__ == "__main__":
    waveset_clustering_effect("input.wav", "output.wav", w=5, clusters_per_sec=15)