Skip to content

Commit 5fcb044

Browse files
Taylor Shin (calclavia)
Taylor Shin
authored and committed
Volume (#39)
* added volume to the network
* change note input dimension after adding volume
* implement volume in generate
* added jazz to training and generation
* mask replay and volume training
* experiment with downscaling volume based on majority
* midi decode for jazz does not work well hmm
* remove computing merged notes
1 parent 1084098 commit 5fcb044

File tree

6 files changed

+111
-47
lines changed

6 files changed

+111
-47
lines changed

constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
# Define the musical styles
4-
styles = ['data/baroque', 'data/classical', 'data/romantic', 'data/modern']
4+
styles = ['data/baroque', 'data/classical', 'data/romantic', 'data/modern', 'data/jazz']
55
# styles = ['data/jazz']
66
NUM_STYLES = len(styles)
77

@@ -34,7 +34,7 @@
3434
OCTAVE_UNITS = 32
3535
STYLE_UNITS = 32
3636
BEAT_UNITS = 32
37-
NOTE_UNITS = 2
37+
NOTE_UNITS = 3
3838
TIME_AXIS_UNITS = 300
3939
NOTE_AXIS_UNITS = 150
4040

generate.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,16 @@ def build_note_inputs(self, note_features):
4545
)
4646

4747
def choose(self, prob, n):
48-
prob = apply_temperature(prob, self.temperature)
48+
vol = prob[n, -1]
49+
prob = apply_temperature(prob[n, :-1], self.temperature)
4950

5051
# Flip notes randomly
51-
if np.random.random() <= prob[n, 0]:
52+
if np.random.random() <= prob[0]:
5253
self.next_note[n, 0] = 1
53-
54+
# Apply volume
55+
self.next_note[n, 2] = vol
5456
# Flip articulation
55-
if np.random.random() <= prob[n, 1]:
57+
if np.random.random() <= prob[1]:
5658
self.next_note[n, 1] = 1
5759

5860
def end_time(self, t):
@@ -93,7 +95,7 @@ def process_inputs(ins):
9395
ins = [np.array(i) for i in ins]
9496
return ins
9597

96-
def generate(models, num_bars=32, styles=[[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]]):
98+
def generate(models, num_bars=32, styles=[[1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1]]):
9799
print('Generating with styles:', styles)
98100

99101
_, time_model, note_model = models

midi_util.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def midi_encode(note_seq, resolution=NOTES_PER_BEAT, step=1):
2020

2121
play = note_seq[:, :, 0]
2222
replay = note_seq[:, :, 1]
23+
volume = note_seq[:, :, 2]
2324

2425
# The current pattern being played
2526
current = np.zeros_like(play[0])
@@ -39,7 +40,7 @@ def midi_encode(note_seq, resolution=NOTES_PER_BEAT, step=1):
3940
# Was off, but now turned on
4041
evt = midi.NoteOnEvent(
4142
tick=(tick - last_event_tick) * step,
42-
velocity=int(next_volume * MAX_VELOCITY),
43+
velocity=int(volume[tick][index[0]] * MAX_VELOCITY),
4344
pitch=index[0]
4445
)
4546
track.append(evt)
@@ -62,7 +63,7 @@ def midi_encode(note_seq, resolution=NOTES_PER_BEAT, step=1):
6263
track.append(evt_off)
6364
evt_on = midi.NoteOnEvent(
6465
tick=0,
65-
velocity=int(current[index] * MAX_VELOCITY),
66+
velocity=int(volume[tick][index[0]] * MAX_VELOCITY),
6667
pitch=index[0]
6768
)
6869
track.append(evt_on)
@@ -103,88 +104,88 @@ def midi_decode(pattern,
103104
step = pattern.resolution // NOTES_PER_BEAT
104105

105106
# Extract all tracks at highest resolution
106-
merged_notes = None
107107
merged_replay = None
108+
merged_volume = None
108109

109110
for track in pattern:
110111
# The downsampled sequences
111-
play_sequence = []
112112
replay_sequence = []
113+
volume_sequence = []
113114

114115
# Raw sequences
115-
play_buffer = [np.zeros((classes,))]
116116
replay_buffer = [np.zeros((classes,))]
117+
volume_buffer = [np.zeros((classes,))]
117118

118119
for i, event in enumerate(track):
119120
# Duplicate the last note pattern to wait for next event
120121
for _ in range(event.tick):
121-
play_buffer.append(np.copy(play_buffer[-1]))
122122
replay_buffer.append(np.zeros(classes))
123+
volume_buffer.append(np.copy(volume_buffer[-1]))
123124

124125
# Buffer & downscale sequence
125-
if len(play_buffer) > step:
126-
# Determine based on majority
127-
notes_sum = np.round(np.sum(play_buffer[:-1], axis=0) / step)
128-
play_sequence.append(play_buffer[0])
129-
130-
# Take the max
126+
if len(volume_buffer) > step:
127+
# Take the min
131128
replay_any = np.minimum(np.sum(replay_buffer[:-1], axis=0), 1)
132129
replay_sequence.append(replay_any)
133130

131+
# Determine volume on rounded sum
132+
volume_sum = np.round(np.sum(volume_buffer[:-1], axis=0) / step)
133+
volume_sequence.append(volume_sum)
134+
134135
# Keep the last one (discard things in the middle)
135-
play_buffer = play_buffer[-1:]
136136
replay_buffer = replay_buffer[-1:]
137+
volume_buffer = volume_buffer[-1:]
137138

138139
if isinstance(event, midi.EndOfTrackEvent):
139140
break
140141

141142
# Modify the last note pattern
142143
if isinstance(event, midi.NoteOnEvent):
143144
pitch, velocity = event.data
144-
play_buffer[-1][pitch] = 1 if velocity > 0 else 0
145+
volume_buffer[-1][pitch] = velocity / MAX_VELOCITY
145146

146147
# Check for replay_buffer, which is true if the current note was previously played and needs to be replayed
147-
if len(play_buffer) > 1 and play_buffer[-2][pitch] > 0 and play_buffer[-1][pitch] > 0:
148+
if len(volume_buffer) > 1 and volume_buffer[-2][pitch] > 0 and volume_buffer[-1][pitch] > 0:
148149
replay_buffer[-1][pitch] = 1
149150
# Override current volume with previous volume
150-
play_buffer[-1][pitch] = play_buffer[-2][pitch]
151+
volume_buffer[-1][pitch] = volume_buffer[-2][pitch]
151152

152153
if isinstance(event, midi.NoteOffEvent):
153154
pitch, velocity = event.data
154-
play_buffer[-1][pitch] = 0
155+
volume_buffer[-1][pitch] = 0
155156

156157
# Add the remaining
157-
play_sequence.append(play_buffer[0])
158158
replay_any = np.minimum(np.sum(replay_buffer, axis=0), 1)
159159
replay_sequence.append(replay_any)
160+
volume_sequence.append(volume_buffer[0])
160161

161-
play_sequence = np.array(play_sequence)
162162
replay_sequence = np.array(replay_sequence)
163-
assert len(play_sequence) == len(replay_sequence)
163+
volume_sequence = np.array(volume_sequence)
164+
assert len(volume_sequence) == len(replay_sequence)
164165

165-
if merged_notes is None:
166-
merged_notes = play_sequence
166+
if merged_volume is None:
167167
merged_replay = replay_sequence
168+
merged_volume = volume_sequence
168169
else:
169170
# Merge into a single track, padding with zeros of needed
170-
if len(play_sequence) > len(merged_notes):
171+
if len(volume_sequence) > len(merged_volume):
171172
# Swap variables such that merged_notes is always at least
172173
# as large as play_sequence
173-
tmp = play_sequence
174-
play_sequence = merged_notes
175-
merged_notes = tmp
176-
177174
tmp = replay_sequence
178175
replay_sequence = merged_replay
179176
merged_replay = tmp
180177

181-
assert len(merged_notes) >= len(play_sequence)
178+
tmp = volume_sequence
179+
volume_sequence = merged_volume
180+
merged_volume = tmp
181+
182+
assert len(merged_volume) >= len(volume_sequence)
182183

183-
diff = len(merged_notes) - len(play_sequence)
184-
merged_notes += np.pad(play_sequence, ((0, diff), (0, 0)), 'constant')
184+
diff = len(merged_volume) - len(volume_sequence)
185185
merged_replay += np.pad(replay_sequence, ((0, diff), (0, 0)), 'constant')
186+
merged_volume += np.pad(volume_sequence, ((0, diff), (0, 0)), 'constant')
186187

187-
merged = np.stack([merged_notes, merged_replay], axis=2)
188+
merged = np.stack([np.ceil(merged_volume), merged_replay, merged_volume], axis=2)
188189
# Prevent stacking duplicate notes to exceed one.
189190
merged = np.minimum(merged, 1)
190191
return merged
@@ -203,13 +204,14 @@ def load_midi(fname):
203204

204205
assert len(note_seq.shape) == 3, note_seq.shape
205206
assert note_seq.shape[1] == MIDI_MAX_NOTES, note_seq.shape
206-
assert note_seq.shape[2] == 2, note_seq.shape
207+
assert note_seq.shape[2] == 3, note_seq.shape
207208
assert (note_seq >= 0).all()
208209
assert (note_seq <= 1).all()
209210
return note_seq
210211

211212
if __name__ == '__main__':
212213
# Test
213-
p = midi.read_midifile("out/test_in.mid")
214+
# p = midi.read_midifile("out/test_in.mid")
215+
p = midi.read_midifile("data/baroque/bach/0864_01.mid")
214216
p = midi_encode(midi_decode(p))
215217
midi.write_midifile("out/test_out.mid", p)

model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
from constants import *
1313

1414
def primary_loss(y_true, y_pred):
15-
return losses.binary_crossentropy(y_true, y_pred)
15+
# 3 separate loss calculations based on if note is played or not
16+
played = y_true[:, :, :, 0]
17+
bce_note = losses.binary_crossentropy(y_true[:, :, :, 0], y_pred[:, :, :, 0])
18+
bce_replay = losses.binary_crossentropy(y_true[:, :, :, 1], tf.multiply(played, y_pred[:, :, :, 1]) + tf.multiply(1 - played, y_true[:, :, :, 1]))
19+
mse = losses.mean_squared_error(y_true[:, :, :, 2], tf.multiply(played, y_pred[:, :, :, 2]) + tf.multiply(1 - played, y_true[:, :, :, 2]))
20+
return bce_note + bce_replay + mse
1621

1722
def style_loss(y_true, y_pred):
1823
return 0.5 * losses.categorical_crossentropy(y_true, y_pred)
@@ -89,7 +94,9 @@ def f(notes, beat, style):
8994
def note_axis(dropout):
9095
dense_layer_cache = {}
9196
lstm_layer_cache = {}
92-
final_dense = Dense(2, activation='sigmoid', name='note_out')
97+
note_dense = Dense(2, activation='sigmoid', name='note_dense')
98+
volume_dense = Dense(1, name='volume_dense')
99+
# final_dense = Concatenate()([note_dense, volume_dense])
93100

94101
def f(x, chosen, style):
95102
time_steps = int(x.get_shape()[1])
@@ -120,7 +127,7 @@ def f(x, chosen, style):
120127
x = Dropout(dropout)(x)
121128

122129
# Primary task
123-
return final_dense(x)
130+
return Concatenate()([note_dense(x), volume_dense(x)])
124131
return f
125132

126133
def style_layer(input_dropout):

test.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,16 @@ def test_encode(self):
2323
[0, 0, 0, 0]
2424
]
2525

26-
pattern = midi_encode(np.stack([composition, replay], 2), step=1)
26+
volume = [
27+
[0, 0.5, 0, 0],
28+
[0, 0.5, 0, 0],
29+
[0, 0.5, 0, 0.5],
30+
[0, 0.5, 0, 0.5],
31+
[0, 0, 0, 0.5],
32+
[0, 0, 0, 0]
33+
]
34+
35+
pattern = midi_encode(np.stack([composition, replay, volume], 2), step=1)
2736
self.assertEqual(pattern.resolution, NOTES_PER_BEAT)
2837
self.assertEqual(len(pattern), 1)
2938
track = pattern[0]
@@ -86,7 +95,16 @@ def test_encode_decode(self):
8695
[0, 0, 0, 0]
8796
]
8897

89-
note_seq = midi_decode(midi_encode(np.stack([composition, replay], 2), step=1), 4, step=1)
98+
volume = [
99+
[0, 0.5, 0, 0],
100+
[0, 0.5, 0, 0],
101+
[0, 0.5, 0, 0.5],
102+
[0, 0.5, 0, 0.5],
103+
[0, 0, 0, 0.5],
104+
[0, 0, 0, 0]
105+
]
106+
107+
note_seq = midi_decode(midi_encode(np.stack([composition, replay, volume], 2), step=1), 4, step=1)
90108
np.testing.assert_array_equal(composition, note_seq[:, :, 0])
91109

92110
def test_replay_decode(self):
@@ -112,6 +130,31 @@ def test_replay_decode(self):
112130
[0., 0., 0., 0.]
113131
])
114132

133+
134+
def test_volume_decode(self):
135+
# Instantiate a MIDI Pattern (contains a list of tracks)
136+
pattern = midi.Pattern(resolution=96)
137+
# Instantiate a MIDI Track (contains a list of MIDI events)
138+
track = midi.Track()
139+
# Append the track to the pattern
140+
pattern.append(track)
141+
142+
track.append(midi.NoteOnEvent(tick=0, velocity=24, pitch=0))
143+
track.append(midi.NoteOnEvent(tick=96, velocity=89, pitch=1))
144+
track.append(midi.NoteOffEvent(tick=0, pitch=0))
145+
track.append(midi.NoteOffEvent(tick=48, pitch=1))
146+
track.append(midi.EndOfTrackEvent(tick=1))
147+
148+
note_seq = midi_decode(pattern, 4, step=DEFAULT_RES // 2)
149+
150+
np.testing.assert_array_almost_equal(note_seq[:, :, 2], [
151+
[24/127, 0., 0., 0.],
152+
[24/127, 0., 0., 0.],
153+
[0., 89/127, 0., 0.],
154+
[0., 0., 0., 0.]
155+
], decimal=5)
156+
157+
115158
def test_replay_encode_decode(self):
116159
# TODO: Fix this test
117160
composition = [
@@ -134,7 +177,17 @@ def test_replay_encode_decode(self):
134177
[0, 0, 0, 0]
135178
]
136179

137-
note_seq = midi_decode(midi_encode(np.stack([composition, replay], 2), step=2), 4, step=2)
180+
volume = [
181+
[0, 0.5, 0, 0.5],
182+
[0, 0, 0, 0.5],
183+
[0, 0, 0, 0.5],
184+
[0, 0.5, 0, 0.5],
185+
[0, 0.5, 0, 0.5],
186+
[0, 0.5, 0, 0.5],
187+
[0, 0, 0, 0]
188+
]
189+
190+
note_seq = midi_decode(midi_encode(np.stack([composition, replay, volume], 2), step=2), 4, step=2)
138191
np.testing.assert_array_equal(composition, note_seq[:, :, 0])
139192
# TODO: Downsampling might have caused loss of information
140193
# np.testing.assert_array_equal(replay, note_seq[:, :, 1])

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def epoch_cb(epoch, _):
2828
write_file(os.path.join(SAMPLES_DIR, 'epoch_{}.mid'.format(epoch)), generate(models))
2929

3030
cbs = [
31-
ModelCheckpoint(MODEL_FILE, monitor='loss', save_best_only=True),
31+
ModelCheckpoint(MODEL_FILE, monitor='loss', save_best_only=True, save_weights_only=True),
3232
EarlyStopping(monitor='loss', patience=5),
3333
TensorBoard(log_dir='out/logs', histogram_freq=1)
3434
]

0 commit comments

Comments (0)