For more information about the examples, such as how the Python and Mojo files interact with each other, see the Examples Overview.
Classifier
Python Code
import sys
from pathlib import Path
import argparse
sys.path.insert(0, str(Path(__file__).parent.parent))
from mmm_python import *
def main():
    """Run the MMMAudio Classifier example.

    Parses the optional ``--src`` command-line argument, creates an
    ``MMMAudio`` instance for the ``Classifier`` graph in the ``examples``
    package, forwards the source path to the graph when one was given, and
    then starts audio.
    """
    parser = argparse.ArgumentParser(description="Run the MMMAudio Classifier example.")
    parser.add_argument("--src", type=str, help="Source audio file to classify", required=False)
    args = parser.parse_args()

    # Alternative output device: 'BlackHole 2ch'
    outdevice = 'default'
    mmm_audio = MMMAudio(
        in_device=None,
        out_device=outdevice,
        blocksize=512,
        graph_name="Classifier",
        package_name="examples",
    )

    if args.src:
        # The Classifier graph checks for updates under the "src_path" key,
        # so the message has to be sent under that same key to be picked up.
        mmm_audio.send_string("src_path", args.src)

    mmm_audio.start_audio()


if __name__ == "__main__":
    main()
Mojo Code
from mmm_audio import *
# Path handed to StandardScaler when a ClassifierWindow is constructed.
comptime scaler_path = "examples/nn_trainings/mfcc_classifier_scaler.joblib"
# Path of the model file loaded with torch.jit.load.
comptime model_path = "examples/nn_trainings/mfcc_classifier_traced.pt"
# Used as the MFCC fft_size and as the window size given to FFTProcess.
comptime windowsize = 1024
# Hop size given to FFTProcess: half of the window size.
comptime hopsize = windowsize // 2
# Number of MFCC coefficients; also the length of the scaled-coefficient
# list and of the torch input tensor.
comptime n_mfcc = 13
struct ClassifierWindow(FFTProcessable):
    """Per-frame processor that feeds scaled MFCCs to a PyTorch model.

    For every frame it computes MFCCs from the magnitudes, scales them,
    runs the loaded model on them, and prints the result.
    """
    var model: PythonObject
    var scaler: StandardScaler
    var mfcc: MFCC
    var scaled_coeffs: List[Float64]
    var py_input: PythonObject
    var py_output: PythonObject

    def __init__(out self, sr: Float64):
        """Set up the scaler, the MFCC analyzer, and the PyTorch model.

        Args:
            sr: Sample rate passed on to the MFCC analyzer.

        Aborts the program if torch cannot be imported or the model cannot
        be loaded.
        """
        self.scaler = StandardScaler(scaler_path)
        self.mfcc = MFCC(sr=sr, fft_size=windowsize, num_coeffs=n_mfcc)
        self.scaled_coeffs = List[Float64](fill=0.0, length=n_mfcc)
        try:
            torch_mod = Python.import_module("torch")
            self.model = torch_mod.jit.load(model_path)
            # Tensor reused as the model input on every frame.
            self.py_input = torch_mod.zeros(n_mfcc)
            # Placeholder output; change the size to match your model's output.
            self.py_output = torch_mod.zeros(1)
        except err:
            abort("Error loading PyTorch model: " + String(err))

    def next_frame(mut self, mut mags: List[Float64], mut phss: List[Float64]):
        """Classify one frame and print the verdict.

        Args:
            mags: Magnitudes of the current frame, used to compute the MFCCs.
            phss: Phases of the current frame; not used here.

        Aborts the program if the Python-side prediction raises.
        """
        self.mfcc.from_mags(mags)
        # Presumably writes the scaled coefficients into self.scaled_coeffs.
        self.scaler.transform_point(self.mfcc.coeffs, self.scaled_coeffs)
        try:
            # Copy the scaled coefficients into the torch input tensor.
            for idx in range(n_mfcc):
                self.py_input[idx] = self.scaled_coeffs[idx]
            self.py_output = self.model(self.py_input)
            score = Float64(py=self.py_output.item())
            # Anything above 0.5 is reported as "dog".
            verdict: String = "❌"
            if score > 0.5:
                verdict = "🐶"
            print("Dog:",verdict,"---", score)
        except err:
            abort("Error predicting: " + String(err))
struct Classifier(Movable,Copyable):
    """Graph that plays a sound file and classifies each of its FFT frames.

    The loaded buffer is played back, every played sample is also pushed
    through an FFTProcess wrapping a ClassifierWindow, and the played
    signal is returned as the output.
    """
    var world: World
    var fftp: FFTProcess[ClassifierWindow,output_window_shape=WindowType.hann]
    var src: Buffer
    var player: Play
    var src_path: String
    var m: Messenger

    def __init__(out self, world: World):
        """Load the default source file and build the processing chain.

        Args:
            world: World handle stored on the struct and passed to the
                FFT process, the player, and the messenger.
        """
        self.world = world
        # NOTE(review): machine-specific absolute default path; on other
        # machines it has to be overridden through the "src_path" message
        # or replaced with a file that exists locally.
        self.src_path = "/Users/ted/Desktop/dog-dataset/Media/Tremblay-BaB-SoundscapeGolcarWithDog.wav"
        # Spell out the same output window shape as the field declaration so
        # the constructed type and the declared type cannot drift apart.
        self.fftp = FFTProcess[ClassifierWindow,output_window_shape=WindowType.hann](self.world, ClassifierWindow(self.world[].sample_rate), windowsize, hopsize)
        self.src = Buffer.load(self.src_path)
        self.player = Play(self.world)
        self.m = Messenger(self.world)

    def next(mut self) -> MFloat[2]:
        """Produce the next output sample and feed it to the classifier.

        Returns:
            The sample produced by the player for the current buffer.
        """
        # Presumably true when a new "src_path" value has arrived and been
        # stored in self.src_path; the buffer is then reloaded from it.
        if self.m.notify_update("src_path", self.src_path):
            self.src = Buffer.load(self.src_path)
        played = self.player.next(self.src)
        # The FFT process output is discarded; it runs only so that the
        # ClassifierWindow sees every frame.
        _ = self.fftp.next(played)
        return played