Skip to content

Commit fc18bdd

Browse files
committed
Refactored into train and generate.py
1 parent f594da7 commit fc18bdd

File tree

5 files changed

+75
-67
lines changed

5 files changed

+75
-67
lines changed

dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from constants import *
1111
from midi_util import load_midi
12-
from util import chunk, get_all_files, one_hot
12+
from util import *
1313

1414
def compute_beat(beat, notes_in_bar):
1515
return one_hot(beat % notes_in_bar, notes_in_bar)

generate.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import numpy as np
2+
import tensorflow as tf
23
from collections import deque
34
import midi
5+
import argparse
46

57
from constants import *
8+
from util import *
69
from dataset import *
710
from tqdm import tqdm
811
from midi_util import midi_encode
@@ -90,7 +93,7 @@ def process_inputs(ins):
9093
ins = [np.array(i) for i in ins]
9194
return ins
9295

93-
def generate(models, num_bars=16, styles=[0, 4, 12, 20]):
96+
def generate(models, num_bars=16, styles=[[1,0,0,0]]):
9497
print('Generating with styles:', styles)
9598

9699
_, time_model, note_model = models
@@ -115,14 +118,26 @@ def generate(models, num_bars=16, styles=[0, 4, 12, 20]):
115118
# Move one time step
116119
yield [g.end_time(t) for g in generations]
117120

118-
def write_file(name, results):
121+
def write_file(fpath, results):
119122
"""
120123
Takes a list of all notes generated per track and writes it to file
121124
"""
122125
results = zip(*list(results))
123126

124127
for i, result in enumerate(results):
125-
fpath = SAMPLES_DIR + '/' + name + '_' + str(i) + '.mid'
126128
print('Writing file', fpath)
129+
os.makedirs(os.path.dirname(fpath), exist_ok=True)
127130
mf = midi_encode(unclamp_midi(result))
128131
midi.write_midifile(fpath, mf)
132+
133+
def main():
134+
# parser = argparse.ArgumentParser(description='Generates music.')
135+
# parser.add_argument('--gen', default=None, nargs='+', help='Style to generate')
136+
# args = parser.parse_args()
137+
138+
with tf.device('cpu:0'):
139+
models = build_or_load()
140+
write_file(os.path.join(SAMPLES_DIR, 'output.mid'), generate(models))
141+
142+
if __name__ == '__main__':
143+
main()

model.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,46 @@ def f(x):
4646
return bins
4747
return f
4848

49+
def time_axis(dropout):
50+
def f(notes, beat, style):
51+
time_steps = int(notes.get_shape()[1])
52+
53+
# TODO: Experiment with when to apply conv
54+
note_octave = TimeDistributed(Conv1D(OCTAVE_UNITS, 2 * OCTAVE, padding='same'))(notes)
55+
note_octave = Activation('tanh')(note_octave)
56+
note_octave = Dropout(dropout)(note_octave)
57+
58+
# Create features for every single note.
59+
note_features = Concatenate()([
60+
Lambda(pitch_pos_in_f(time_steps))(notes),
61+
Lambda(pitch_class_in_f(time_steps))(notes),
62+
Lambda(pitch_bins_f(time_steps))(notes),
63+
note_octave,
64+
TimeDistributed(RepeatVector(NUM_NOTES))(beat)
65+
])
66+
67+
x = note_features
68+
69+
# [batch, notes, time, features]
70+
x = Permute((2, 1, 3))(x)
71+
72+
# Apply LSTMs
73+
for l in range(TIME_AXIS_LAYERS):
74+
# Integrate style
75+
style_proj = Dense(int(x.get_shape()[3]))(style)
76+
style_proj = Activation('tanh')(style_proj)
77+
style_proj = Dropout(dropout)(style_proj)
78+
style_proj = TimeDistributed(RepeatVector(NUM_NOTES))(style_proj)
79+
style_proj = Permute((2, 1, 3))(style_proj)
80+
x = Add()([x, style_proj])
81+
82+
x = TimeDistributed(LSTM(TIME_AXIS_UNITS, return_sequences=True))(x)
83+
x = Dropout(dropout)(x)
84+
85+
# [batch, time, notes, features]
86+
return Permute((2, 1, 3))(x)
87+
return f
88+
4989
def note_axis(dropout):
5090
def f(x, chosen, style):
5191
time_steps = int(x.get_shape()[1])
@@ -74,7 +114,7 @@ def f(x, chosen, style):
74114
return f
75115

76116
def style_layer(input_dropout):
77-
emb = Embedding(NUM_STYLES, STYLE_UNITS)
117+
emb = Dense(STYLE_UNITS)
78118
def f(style_in):
79119
style = emb(style_in)
80120
return Dropout(input_dropout)(style)
@@ -83,7 +123,7 @@ def f(style_in):
83123
def build_models(time_steps=SEQ_LEN, input_dropout=0.2, dropout=0.5):
84124
notes_in = Input((time_steps, NUM_NOTES, NOTE_UNITS))
85125
beat_in = Input((time_steps, NOTES_PER_BAR))
86-
style_in = Input((time_steps,))
126+
style_in = Input((time_steps, NUM_STYLES))
87127
# Target input for conditioning
88128
chosen_in = Input((time_steps, NUM_NOTES, NOTE_UNITS))
89129

@@ -97,40 +137,7 @@ def build_models(time_steps=SEQ_LEN, input_dropout=0.2, dropout=0.5):
97137
style = style_l(style_in)
98138

99139
""" Time axis """
100-
# TODO: Experiment with when to apply conv
101-
note_octave = TimeDistributed(Conv1D(OCTAVE_UNITS, 2 * OCTAVE, padding='same'))(notes)
102-
note_octave = Activation('tanh')(note_octave)
103-
note_octave = Dropout(dropout)(note_octave)
104-
105-
# Create features for every single note.
106-
note_features = Concatenate()([
107-
Lambda(pitch_pos_in_f(time_steps))(notes),
108-
Lambda(pitch_class_in_f(time_steps))(notes),
109-
Lambda(pitch_bins_f(time_steps))(notes),
110-
note_octave,
111-
TimeDistributed(RepeatVector(NUM_NOTES))(beat)
112-
])
113-
114-
x = note_features
115-
116-
# [batch, notes, time, features]
117-
x = Permute((2, 1, 3))(x)
118-
119-
# Apply LSTMs
120-
for l in range(TIME_AXIS_LAYERS):
121-
# Integrate style
122-
style_proj = Dense(int(x.get_shape()[3]))(style)
123-
style_proj = Activation('tanh')(style_proj)
124-
style_proj = Dropout(dropout)(style_proj)
125-
style_proj = TimeDistributed(RepeatVector(NUM_NOTES))(style_proj)
126-
style_proj = Permute((2, 1, 3))(style_proj)
127-
x = Add()([x, style_proj])
128-
129-
x = TimeDistributed(LSTM(TIME_AXIS_UNITS, return_sequences=True))(x)
130-
x = Dropout(dropout)(x)
131-
132-
# [batch, time, notes, features]
133-
time_out = Permute((2, 1, 3))(x)
140+
time_out = time_axis(dropout)(notes, beat, style)
134141

135142
""" Note Axis & Prediction Layer """
136143
naxis = note_axis(dropout)
@@ -144,7 +151,7 @@ def build_models(time_steps=SEQ_LEN, input_dropout=0.2, dropout=0.5):
144151

145152
note_features = Input((1, NUM_NOTES, TIME_AXIS_UNITS), name='note_features')
146153
chosen_gen_in = Input((1, NUM_NOTES, NOTE_UNITS), name='chosen_gen_in')
147-
style_gen_in = Input((1,), name='style_in')
154+
style_gen_in = Input((1, NUM_STYLES), name='style_in')
148155

149156
# Dropout inputs
150157
chosen_gen = Dropout(input_dropout)(chosen_gen_in)

main.py renamed to train.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,11 @@
1313

1414
def main():
1515
parser = argparse.ArgumentParser(description='Generates music.')
16-
parser.add_argument('--train', default=False, action='store_true', help='Train model?')
1716
parser.add_argument('--gen', default=False, action='store_true', help='Generate after each epoch?')
1817
args = parser.parse_args()
1918

2019
models = build_or_load()
21-
22-
if args.train:
23-
train(models, args.gen)
24-
else:
25-
write_file(os.path.join(SAMPLES_DIR, 'output.mid'), generate(models))
26-
27-
def build_or_load(allow_load=True):
28-
models = build_models()
29-
models[0].summary()
30-
if allow_load:
31-
try:
32-
models[0].load_weights(MODEL_FILE)
33-
print('Loaded model from file.')
34-
except:
35-
print('Unable to load model from file.')
36-
return models
20+
train(models, args.gen)
3721

3822
def train(models, gen):
3923
print('Loading data')

util.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@ def one_hot(i, nb_classes):
1010
arr[i] = 1
1111
return arr
1212

13-
def chunk(a, size):
14-
# Zero pad extra spaces
15-
target_size = math.ceil(len(a) / float(size)) * size
16-
inc_size = target_size - len(a)
17-
assert inc_size >= 0 and inc_size < size, inc_size
18-
a = np.array(a)
19-
a = np.pad(a, [(0, inc_size)] + [(0, 0) for i in range(len(a.shape) - 1)], mode='constant')
20-
assert a.shape[0] == target_size
21-
return np.swapaxes(np.split(a, size), 0, 1)
13+
def build_or_load(allow_load=True):
14+
from model import build_models
15+
models = build_models()
16+
models[0].summary()
17+
if allow_load:
18+
try:
19+
models[0].load_weights(MODEL_FILE)
20+
print('Loaded model from file.')
21+
except:
22+
print('Unable to load model from file.')
23+
return models
2224

2325
def get_all_files(paths):
2426
potential_files = []

0 commit comments

Comments
 (0)