Skip to content

For more information about the examples, such as how the Python and Mojo files interact with each other, see the Examples Overview.

TorchMlp

This example uses a Torch MLP model to control a 16-parameter synth. To play the synth, just run the first few lines of code and play with the mouse.

You can also retrain the synth by creating any number of input/output pairs and running a new training.

Python Code

# Boot the audio engine. The `if True:` wrapper groups these lines so an
# editor's "run selection" executes them together in the REPL session.
if True:
    from mmm_src.MMMAudio import MMMAudio
    from random import random

    # 128 is presumably the audio block size — confirm in the MMMAudio docs.
    # "TorchMlp" names the Mojo graph struct; "examples" is its package.
    mmm_audio = MMMAudio(128, graph_name="TorchMlp", package_name="examples")

    # this one is a bit intense, so maybe start with a low volume
    mmm_audio.start_audio()

# Run this line on its own when you want to stop playback.
mmm_audio.stop_audio()

# below is the code to make a new training --------------------------------

# toggle inference off so you can set the synth values directly
# NOTE(review): run ONE of these two lines at a time in the REPL —
# presumably True/False switches inference on/off; confirm which value
# disables it against the Mojo MLP implementation.
mmm_audio.send_bool("mlp1.toggle_inference", True)
mmm_audio.send_bool("mlp1.toggle_inference", False)

# how many outputs does your mlp have?
out_size = 16  # must match the model_out_size used by the Mojo synth

# create lists to hold your training data
X_train_list = []  # inputs: one [x, y] pair per example (mouse position)
y_train_list = []  # targets: one `out_size`-value synth setting per example

def make_setting():
    """Generate one random synth setting and send it to the synth.

    Builds `out_size` random values in [0, 1), prints them, and pushes them
    to the running synth via the "mlp1.fake_model_output" message so you can
    audition the setting before deciding to train on it.

    Returns:
        list[float]: the random values that were sent.
    """
    # comprehension instead of a manual append loop (same values, clearer)
    setting = [random() for _ in range(out_size)]
    print("setting =", setting)
    mmm_audio.send_floats("mlp1.fake_model_output", setting)

    return setting

# create an output setting to train on (audition it before keeping it)
outputs = make_setting()

# print out what you have so far
for i in range(len(y_train_list)):
    print(f"Element {i}: {X_train_list[i]}")
    print(f"Element {i}: {y_train_list[i]}")

# when you like a setting add an input and output pair
# this is assuming you are training on 4 pairs of data points - you do as many as you like
# The inputs below are the four corners of the (x, y) input square, since
# the Mojo graph feeds mouse_x/mouse_y (0..1) into the model.
# Re-run make_setting() until you like the sound, THEN append the pair.

outputs = make_setting()

X_train_list.append([0,0])
y_train_list.append(outputs)

outputs = make_setting()

X_train_list.append([0,1])
y_train_list.append(outputs)

outputs = make_setting()

X_train_list.append([1,1])
y_train_list.append(outputs)

outputs = make_setting()

X_train_list.append([1,0])
y_train_list.append(outputs)

# once you have filled the X_train_list and y_train_list, train the network on your data

def do_the_training():
    """Train the MLP on the collected (X_train_list, y_train_list) pairs.

    Training runs in a background thread so the audio thread doesn't get
    interrupted. The traced model is written to
    "examples/nn_trainings/model_traced.pt".

    Returns:
        threading.Thread: the running training thread. Join it (or wait for
        the trainer's console output to finish) before loading the new model
        into the synth — otherwise the load can race a half-written file.
    """
    print("training the network")
    learn_rate = 0.001
    epochs = 5000

    # two hidden relu layers; sigmoid on the output keeps every value in
    # [0, 1], matching the range the synth parameters expect
    layers = [ [ 64, "relu" ], [ 64, "relu" ], [ out_size, "sigmoid" ] ]

    # imported here so torch/training code only loads when actually training
    from mmm_utils.mlp_trainer import train_nn
    import threading

    args = (X_train_list, y_train_list, layers, learn_rate, epochs, "examples/nn_trainings/model_traced.pt")

    # run in a separate thread so the audio thread doesn't get interrupted
    training_thread = threading.Thread(target=train_nn, args=args)
    training_thread.start()
    # expose the thread so the caller can .join() before loading the model
    return training_thread

do_the_training()

# load the new training into the synth
# NOTE(review): training runs on a background thread — make sure it has
# finished (watch for the trainer's output to stop) before running this line,
# or the synth may load a stale/partial file.
mmm_audio.send_string("mlp1.load_mlp_training", "examples/nn_trainings/model_traced.pt")  

# turn inference back on so the (new) model drives the synth again
# (presumably True = inference on — confirm against the Mojo MLP code)
mmm_audio.send_bool("mlp1.toggle_inference", True)

Mojo Code

from mmm_src.MMMWorld import MMMWorld
from mmm_utils.functions import *
from mmm_utils.Messenger import *

from mmm_dsp.Osc import *
from math import tanh
from random import random_float64
from mmm_dsp.Filters import *

from mmm_dsp.MLP import MLP
from mmm_dsp.Distortion import *
from sys import simd_width_of

# lag-smoothing vector width: two hardware SIMD registers of Float64
# (the *2 is presumably a throughput choice — TODO confirm)
alias simd_width = simd_width_of[DType.float64]() * 2
alias model_out_size = 16  # Define the output size of the model
alias num_simd = (model_out_size + simd_width - 1) // simd_width  # number of SIMD groups needed (ceiling division)

# THE SYNTH (in the repo this lives in TorchSynth.mojo in this directory)
# Two wavetable oscillators with cross-feedback FM. All 16 control values
# come from the lag-smoothed outputs of the MLP model, which is itself
# driven by the mouse position (see next()).
struct TorchSynth(Movable, Copyable):
    var world: UnsafePointer[MMMWorld]  # Pointer to the MMMWorld instance
    var osc1: Osc[1, 2, 1]
    var osc2: Osc[1, 2, 1]

    var model: MLP[2, model_out_size]  # Instance of the MLP model - 2 inputs, model_out_size outputs
    var lags: List[Lag[simd_width]]  # one smoothing Lag per SIMD group of model outputs
    var lag_vals: List[Float64]  # smoothed model outputs, one scalar per synth parameter

    var fb: Float64  # osc2's previous output sample, fed back into osc1's frequency

    var latch1: Latch  # sample-and-hold units for samplerate reduction
    var latch2: Latch
    var impulse1: Impulse  # clocks driving the latches
    var impulse2: Impulse

    var filt1: SVF  # lowpass filters, one per oscillator
    var filt2: SVF

    var dc1: DCTrap  # DC blockers applied after the tanh waveshaping
    var dc2: DCTrap

    fn __init__(out self, world: UnsafePointer[MMMWorld]):
        self.world = world
        self.osc1 = Osc[1, 2, 1](world)
        self.osc2 = Osc[1, 2, 1](world)

        # load the trained model; "mlp1" matches the message prefix used from
        # Python (mlp1.toggle_inference, mlp1.load_mlp_training, ...).
        # trig_rate=25.0 presumably means 25 inferences per second — TODO confirm
        self.model = MLP(world,"examples/nn_trainings/model_traced.pt", "mlp1", trig_rate=25.0)

        # make a lag for each output of the nn - pair them in twos for SIMD processing
        # self.lag_vals = InlineArray[Float64, model_out_size](fill=random_float64())
        self.lag_vals = [random_float64() for _ in range(model_out_size)]
        # 1/25.0 lag time matches the model's trig_rate above, so parameters
        # glide smoothly between successive inferences
        self.lags = [Lag[simd_width](self.world, 1/25.0) for _ in range(num_simd)]

        # create a feedback variable so each of the oscillators can feedback on each sample
        self.fb = 0.0

        self.latch1 = Latch(self.world)
        self.latch2 = Latch(self.world)
        self.impulse1 = Impulse(self.world)
        self.impulse2 = Impulse(self.world)
        self.filt1 = SVF(self.world)
        self.filt2 = SVF(self.world)
        self.dc1 = DCTrap(self.world)
        self.dc2 = DCTrap(self.world)

    @always_inline
    fn next(mut self) -> SIMD[DType.float64, 2]:
        # Compute one stereo output sample: mouse -> model -> smoothed
        # parameters -> two-oscillator feedback synth.
        self.model.model_input[0] = self.world[].mouse_x
        self.model.model_input[1] = self.world[].mouse_y

        self.model.next()  # Run the model inference

        @parameter
        for i in range(num_simd):
            # gather this group's scalars into a SIMD vector, smooth it with
            # one Lag, then scatter the smoothed values back to lag_vals.
            # The `idx < model_out_size` guards handle the final partial
            # group when model_out_size isn't a multiple of simd_width.
            model_output_simd = SIMD[DType.float64, simd_width](0.0)
            for j in range(simd_width):
                idx = i * simd_width + j
                if idx < model_out_size:
                    model_output_simd[j] = self.model.model_output[idx]
            lagged_output = self.lags[i].next(model_output_simd)
            for j in range(simd_width):
                idx = i * simd_width + j
                if idx < model_out_size:
                    self.lag_vals[idx] = lagged_output[j]

        # uncomment to see the output of the model
        # self.world[].print(self.lag_vals[0], self.lag_vals[1], self.lag_vals[2], self.lag_vals[3], self.lag_vals[4], self.lag_vals[5], self.lag_vals[6], self.lag_vals[7], self.lag_vals[8], self.lag_vals[9], self.lag_vals[10], self.lag_vals[11], self.lag_vals[12], self.lag_vals[13], self.lag_vals[14], self.lag_vals[15])

        # oscillator 1 -----------------------

        # base pitch (exponential, 1..3000 Hz) plus FM from osc2's previous
        # sample (self.fb), with index scaled up to 5000
        var freq1 = linexp(self.lag_vals[0], 0.0, 1.0, 1.0, 3000) + (linlin(self.lag_vals[1], 0.0, 1.0, 2.0, 5000.0) * self.fb)
        # var which_osc1 = lag_vals[2] #not used...whoops

        # next_interp implements a variable wavetable oscillator between the N provided wave types
        # in this case, we are using 0, 4, 5, 6 - Sine, BandLimited Tri, BL Saw, BL Square
        # NOTE(review): linlin over identical 0..1 ranges is an identity mapping
        osc_frac1 = linlin(self.lag_vals[3], 0.0, 1.0, 0.0, 1.0)
        osc1 = self.osc1.next_interp(freq1, 0.0, False, [0,4,5,6], osc_frac1)

        # samplerate reduction: hold osc1 at a rate between 100 Hz and Nyquist
        osc1 = self.latch1.next(osc1, self.impulse1.next_bool(linexp(self.lag_vals[4], 0.0, 1.0, 100.0, self.world[].sample_rate*0.5)))
        # lowpass: cutoff 100..20k Hz (exponential), Q 0.707..4
        osc1 = self.filt1.lpf(osc1, linexp(self.lag_vals[5], 0.0, 1.0, 100.0, 20000.0), linlin(self.lag_vals[6], 0.0, 1.0, 0.707, 4.0))

        tanh_gain = linlin(self.lag_vals[7], 0.0, 1.0, 0.5, 10.0)

        # tanh waveshaping, then get rid of any dc offset it introduces
        osc1 = tanh(osc1*tanh_gain)
        osc1 = self.dc1.next(osc1)

        # oscillator 2 -----------------------

        # same recipe, but FM'd by osc1's current sample (linear base pitch)
        var freq2 = linlin(self.lag_vals[8], 0.0, 1.0, 2.0, 5000.0) + (linlin(self.lag_vals[9], 0.0, 1.0, 2.0, 5000.0) * osc1)
        # var which_osc2 = self.lag_vals[10] #not used...whoops

        osc_frac2 = linlin(self.lag_vals[11], 0.0, 1.0, 0.0, 1.0)
        var osc2 = self.osc2.next_interp(freq2, 0.0, False, [0,4,5,6], osc_frac2)

        osc2 = self.latch2.next(osc2, self.impulse2.next_bool(linexp(self.lag_vals[12], 0.0, 1.0, 100.0, self.world[].sample_rate*0.5)))

        osc2 = self.filt2.lpf(osc2, linexp(self.lag_vals[13], 0.0, 1.0, 100.0, 20000.0), linlin(self.lag_vals[14], 0.0, 1.0, 0.707, 4.0))

        tanh_gain = linlin(self.lag_vals[15], 0.0, 1.0, 0.5, 10.0)
        osc2 = tanh(osc2*tanh_gain)
        osc2 = self.dc2.next(osc2)
        # store osc2 for next sample's osc1 FM (the cross-feedback path)
        self.fb = osc2

        # osc1 left, osc2 right, attenuated to a safe listening level
        return SIMD[DType.float64, 2](osc1, osc2) * 0.1


# THE GRAPH - the top-level audio graph the engine instantiates by name

struct TorchMlp():
    # Shared engine state plus the single synth voice this graph renders.
    var world: UnsafePointer[MMMWorld]
    var torch_synth: TorchSynth

    fn __init__(out self, world: UnsafePointer[MMMWorld]):
        # Keep the world pointer and build the synth voice from it.
        self.world = world
        self.torch_synth = TorchSynth(world)

    fn next(mut self) -> SIMD[DType.float64, 2]:
        # Pure delegation: each stereo sample comes straight from the synth.
        return self.torch_synth.next()