Skip to content

For more information about the examples, such as how the Python and Mojo files interact with each other, see the Examples Overview.

TorchMlp

This example uses a PyTorch MLP model to control a 16-parameter synth. To play the synth, just execute the top "if" statement and play the synth with the mouse. The X and Y positions of the mouse control the two input parameters of the MLP, which then outputs 16 parameters to control the synth.

You can also train the Multi-Layer Perceptron (MLP) by creating any number of input/output pairs and running a new training. While collecting training data, temporarily disable MLP inference so you can set the synth parameters directly.

Python Code

if True:
    # Setup: run this whole "if" block at once to start the example, then
    # play the synth by moving the mouse.
    from mmm_python import *
    from random import random  # used by make_setting() further down

    # graph_name="TorchMlp" matches the TorchMlp graph struct in the Mojo code
    mmm_audio = MMMAudio(128, in_device=None, graph_name="TorchMlp", package_name="examples")

    # this one is a bit intense, so maybe start with a low volume
    mmm_audio.start_audio()

# below is the code to make a new training --------------------------------

# toggle MLP inference: run the True line to turn it on, or the False line to
# turn it off so you can set the synth values directly while collecting data
mmm_audio.send_bool("mlp1.toggle_inference", True)
mmm_audio.send_bool("mlp1.toggle_inference", False)

# how many outputs does your mlp have?
# (should match model_out_size in the Mojo code, which is also 16)
out_size = 16

# create lists to hold your training data:
# X_train_list holds the 2-value inputs (X/Y positions),
# y_train_list holds the matching out_size-value synth settings
X_train_list = []
y_train_list = []

def make_setting():
    """Generate a random synth setting and send it to the synth.

    Draws ``out_size`` values with ``random()`` (each in [0, 1)), prints
    them, and sends them to the synth as ``mlp1.fake_model_output`` so the
    setting can be auditioned while MLP inference is switched off.

    Returns:
        list[float]: the generated setting, ready to append to
        ``y_train_list``.
    """
    values = [random() for _ in range(out_size)]
    print("setting =", values)
    mmm_audio.send_floats("mlp1.fake_model_output", values)
    return values

# create an output setting to train on
outputs = make_setting()

# print out what you have so far: each input position followed by its setting
if len(X_train_list) != len(y_train_list):
    # the lists are appended in pairs below; a mismatch means one append line
    # was run without the other, so the pairs no longer line up
    print(f"warning: {len(X_train_list)} inputs but {len(y_train_list)} outputs")
for i, (x_in, y_out) in enumerate(zip(X_train_list, y_train_list)):
    print(f"Element {i}: {x_in}")
    print(f"Element {i}: {y_out}")

# When you like a setting, add it as an input/output pair.
# Each chunk below calls make_setting() for a new random setting: re-run that
# line until you like what you hear, then run the two append lines to pair the
# current setting with an X/Y input position.
# This assumes you are training on 4 pairs of data points (the corner inputs
# [0,0], [0,1], [1,1], [1,0]) - you can add as many as you like.

outputs = make_setting()

X_train_list.append([0,0])
y_train_list.append(outputs)

outputs = make_setting()

X_train_list.append([0,1])
y_train_list.append(outputs)

outputs = make_setting()

X_train_list.append([1,1])
y_train_list.append(outputs)

outputs = make_setting()

X_train_list.append([1,0])
y_train_list.append(outputs)

# once you have filled the X_train_list and y_train_list, train the network on your data

def do_the_training(learn_rate=0.001, epochs=5000,
                    model_path="examples/nn_trainings/model_traced.pt"):
    """Train a new MLP on the collected pairs in a background thread.

    Reads the module-level ``X_train_list`` (inputs) and ``y_train_list``
    (``out_size``-value synth settings) and hands them to ``train_nn``
    together with a 64-64-``out_size`` layer layout. Training runs on a
    separate thread so the audio thread isn't interrupted.

    Args:
        learn_rate: learning rate passed to ``train_nn``.
        epochs: number of epochs passed to ``train_nn``.
        model_path: path passed to ``train_nn`` for the traced model. The
            default is the same file the Mojo ``MLP`` loads at startup and
            the path sent with ``mlp1.load_mlp_training``.

    Returns:
        threading.Thread: the started training thread. Call ``.join()`` (or
        poll ``.is_alive()``) before loading the new training, so you don't
        load a file that has not been fully written yet.

    Raises:
        ValueError: if no training pairs have been collected, or the input
            and output lists have different lengths.
    """
    # fail fast in the caller's thread rather than inside the background one,
    # where the error is easy to miss and a stale model might get loaded
    if not X_train_list or not y_train_list:
        raise ValueError("no training data: add input/output pairs first")
    if len(X_train_list) != len(y_train_list):
        raise ValueError(
            f"{len(X_train_list)} inputs but {len(y_train_list)} outputs; "
            "every input needs a matching output"
        )

    print("training the network")

    # sigmoid output keeps predictions in 0..1, the range make_setting produces
    layers = [[64, "relu"], [64, "relu"], [out_size, "sigmoid"]]

    from mmm_audio.MLP_Python import train_nn
    import threading

    # train the network in a separate thread so the audio thread doesn't get interrupted
    args = (X_train_list, y_train_list, layers, learn_rate, epochs, model_path)
    training_thread = threading.Thread(target=train_nn, args=args)
    training_thread.start()
    return training_thread

# start training; it runs in a background thread and train_nn is given the
# same .pt path that is loaded below, so wait for training to finish first
do_the_training()

# load the new training into the synth
mmm_audio.send_string("mlp1.load_mlp_training", "examples/nn_trainings/model_traced.pt")

# toggle inference back on so the MLP drives the synth parameters again
mmm_audio.send_bool("mlp1.toggle_inference", True)

Mojo Code

from mmm_audio import *

from std.sys import simd_width_of

comptime model_out_size = 16  # Number of synth parameters the MLP outputs; should match out_size used when training on the Python side

# THE SYNTH - defined inline here; the TorchMlp graph below creates one instance.
# On every call to next() it feeds the mouse X/Y position to a 2-input MLP,
# smooths the model_out_size outputs with Lags, and uses them as parameters for
# two oscillator chains (oscillator -> sample-rate reduction -> low-pass ->
# tanh drive -> DC trap) that frequency-modulate each other.
# NOTE(review): self.lags[2] and self.lags[10] are never read below, so 2 of the
# 16 model outputs currently have no effect on the sound - confirm intentional.
struct TorchSynth(Movable, Copyable):
    var world: World  # Pointer to the MMMWorld instance
    var osc1: Osc[1, Interp.sinc, 1]  # oscillator of chain 1 (output channel 0)
    var osc2: Osc[1, Interp.sinc, 1]  # oscillator of chain 2 (output channel 1)

    var model: MLP[2, model_out_size]  # Instance of the MLP model - 2 inputs, model_out_size outputs
    var lags: Lags[model_out_size]  # A Lags (Lags processed in parallel) for smoothing the model outputs

    var fb: Float64  # last output sample of chain 2, used to frequency-modulate oscillator 1

    # latch + phasor trigger pairs implement the sample-rate reduction of each chain
    var latch1: Latch[]
    var latch2: Latch[]
    var impulse1: Phasor[]
    var impulse2: Phasor[]

    # per-chain low-pass filters
    var filt1: SVF[]
    var filt2: SVF[]

    # per-chain DC-offset removal
    var dc1: DCTrap[]
    var dc2: DCTrap[]

    def __init__(out self, world: World):
        self.world = world
        self.osc1 = Osc[1, Interp.sinc, 1](self.world)
        self.osc2 = Osc[1, Interp.sinc, 1](self.world)

        # load the trained model; "mlp1" appears to be the name the Python side
        # addresses (e.g. "mlp1.toggle_inference", "mlp1.load_mlp_training")
        self.model = MLP(self.world,"examples/nn_trainings/model_traced.pt", "mlp1", trig_rate=25.0)

        # Lags is a utility for processing multiple lag lines in parallel
        self.lags = Lags[model_out_size](self.world, 1/25.0)  # Assuming the model updates at 25 Hz

        # create a feedback variable so each of the oscillators can feedback on each sample
        self.fb = 0.0

        self.latch1 = Latch()
        self.latch2 = Latch()
        self.impulse1 = Phasor(self.world)
        self.impulse2 = Phasor(self.world)
        self.filt1 = SVF(self.world)
        self.filt2 = SVF(self.world)
        self.dc1 = DCTrap(self.world)
        self.dc2 = DCTrap(self.world)

    # Compute one stereo frame: channel 0 is chain 1 (osc1), channel 1 is chain 2 (osc2).
    @always_inline
    def next(mut self) -> MFloat[2]:
        # the mouse position is the model's 2 inputs
        self.model.model_input[0] = self.world[].mouse_x
        self.model.model_input[1] = self.world[].mouse_y

        self.model.next()  # Run the model inference

        self.lags.next(self.model.model_output)  # Get the lagged outputs for smoother control

        # uncomment to see the output of the model
        # self.world[].print(self.lags[0], self.lags[1], self.lags[2], self.lags[3], self.lags[4], self.lags[5], self.lags[6], self.lags[7], self.lags[8], self.lags[9], self.lags[10], self.lags[11], self.lags[12], self.lags[13], self.lags[14], self.lags[15])

        # oscillator 1 -----------------------

        # base frequency 1..3000 (exponential, lags[0]) plus FM from the previous
        # chain-2 sample with depth 2..5000 (lags[1])
        var freq1 = linexp(self.lags[0], 0.0, 1.0, 1.0, 3000) + (linlin(self.lags[1], 0.0, 1.0, 2.0, 5000.0) * self.fb)

        # next_basic_waveforms picks/interpolates between the listed waveform
        # indices [0,1,2,3] using osc_frac1 (presumably sine, tri, saw, square -
        # confirm against the Osc documentation)
        # the linlin below is currently an identity mapping (0..1 -> 0..1)
        osc_frac1 = linlin(self.lags[3], 0.0, 1.0, 0.0, 1.0)
        osc1 = self.osc1.next_basic_waveforms(freq1, 0.0, False, [0,1,2,3], osc_frac1)

        # samplerate reduction: latch osc1 at the impulse1 rate, 100 Hz up to
        # Nyquist (sample_rate/2), set by lags[4]
        osc1 = self.latch1.next(osc1, self.impulse1.next_bool(linexp(self.lags[4], 0.0, 1.0, 100.0, self.world[].sample_rate*0.5)))
        # low-pass: cutoff 100..20000 (lags[5]); last arg 0.707..4.0 (lags[6]) is presumably Q
        osc1 = self.filt1.lpf(osc1, linexp(self.lags[5], 0.0, 1.0, 100.0, 20000.0), linlin(self.lags[6], 0.0, 1.0, 0.707, 4.0))

        # drive amount for the tanh soft-clip (lags[7])
        tanh_gain = linlin(self.lags[7], 0.0, 1.0, 0.5, 10.0)

        osc1 = tanh(osc1*tanh_gain)
        # get rid of dc offset
        osc1 = self.dc1.next(osc1)

        # oscillator 2 -----------------------

        # base frequency 2..5000 (lags[8]) plus FM from this sample's osc1 with
        # depth 2..5000 (lags[9])
        var freq2 = linlin(self.lags[8], 0.0, 1.0, 2.0, 5000.0) + (linlin(self.lags[9], 0.0, 1.0, 2.0, 5000.0) * osc1)

        # waveform position for osc2 (identity mapping of lags[11])
        osc_frac2 = linlin(self.lags[11], 0.0, 1.0, 0.0, 1.0)
        var osc2 = self.osc2.next_basic_waveforms(freq2, 0.0, False, [0,1,2,3], osc_frac2)

        # samplerate reduction for chain 2 (lags[12])
        osc2 = self.latch2.next(osc2, self.impulse2.next_bool(linexp(self.lags[12], 0.0, 1.0, 100.0, self.world[].sample_rate*0.5)))

        # low-pass for chain 2: cutoff (lags[13]), presumably Q (lags[14])
        osc2 = self.filt2.lpf(osc2, linexp(self.lags[13], 0.0, 1.0, 100.0, 20000.0), linlin(self.lags[14], 0.0, 1.0, 0.707, 4.0))

        # tanh drive (lags[15]), then remove dc offset
        tanh_gain = linlin(self.lags[15], 0.0, 1.0, 0.5, 10.0)
        osc2 = tanh(osc2*tanh_gain)
        osc2 = self.dc2.next(osc2)
        # remember chain 2's output to modulate oscillator 1 on the next call
        self.fb = osc2

        # scale both channels down by 0.1
        return MFloat[2](osc1, osc2) * 0.1


# THE GRAPH
# Top-level graph for this example (graph_name="TorchMlp" on the Python side):
# it owns a single TorchSynth and passes its stereo output straight through.
struct TorchMlp(Movable, Copyable):
    var world: World  # handle to the shared MMMWorld
    var torch_synth: TorchSynth  # the MLP-controlled synth voice

    def __init__(out self, world: World):
        # keep the world handle, then build the synth on that same handle
        self.world = world
        self.torch_synth = TorchSynth(self.world)

    def next(mut self) -> MFloat[2]:
        # one stereo frame per call, produced entirely by the synth
        var frame = self.torch_synth.next()
        return frame