diff --git a/__pycache__/test.cpython-311.pyc b/__pycache__/test.cpython-311.pyc new file mode 100644 index 0000000..fa4ed36 Binary files /dev/null and b/__pycache__/test.cpython-311.pyc differ diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000..ebcc036 Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ diff --git a/data/__pycache__/__init__.cpython-311.pyc b/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..97083b5 Binary files /dev/null and b/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/data/__pycache__/dataset.cpython-311.pyc b/data/__pycache__/dataset.cpython-311.pyc new file mode 100644 index 0000000..e3a9686 Binary files /dev/null and b/data/__pycache__/dataset.cpython-311.pyc differ diff --git a/data/__pycache__/musdb.cpython-311.pyc b/data/__pycache__/musdb.cpython-311.pyc new file mode 100644 index 0000000..9794c87 Binary files /dev/null and b/data/__pycache__/musdb.cpython-311.pyc differ diff --git a/data/__pycache__/utils.cpython-311.pyc b/data/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000..ffa1a9d Binary files /dev/null and b/data/__pycache__/utils.cpython-311.pyc differ diff --git a/model/__pycache__/__init__.cpython-311.pyc b/model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..b03d70f Binary files /dev/null and b/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/model/__pycache__/conv.cpython-311.pyc b/model/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000..e729502 Binary files /dev/null and b/model/__pycache__/conv.cpython-311.pyc differ diff --git a/model/__pycache__/crop.cpython-311.pyc b/model/__pycache__/crop.cpython-311.pyc new file mode 100644 index 0000000..ec9da62 Binary files /dev/null and b/model/__pycache__/crop.cpython-311.pyc differ diff --git a/model/__pycache__/resample.cpython-311.pyc b/model/__pycache__/resample.cpython-311.pyc new file mode 100644 index 0000000..c562864 Binary files /dev/null and b/model/__pycache__/resample.cpython-311.pyc differ diff --git a/model/__pycache__/utils.cpython-311.pyc b/model/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000..bd5fe7f Binary files /dev/null and b/model/__pycache__/utils.cpython-311.pyc differ diff --git a/model/__pycache__/waveunet.cpython-311.pyc b/model/__pycache__/waveunet.cpython-311.pyc new file mode 100644 index 0000000..32b2419 Binary files /dev/null and b/model/__pycache__/waveunet.cpython-311.pyc differ diff --git a/model/utils.py b/model/utils.py index 4c92c95..bc7738b 100644 --- a/model/utils.py +++ b/model/utils.py @@ -13,13 +13,13 @@ def save_model(model, optimizer, state, path): }, path) -def load_model(model, optimizer, path, cuda): + +def load_model(model, optimizer, path, device): if isinstance(model, torch.nn.DataParallel): model = model.module # load state dict of wrapped module - if cuda: - checkpoint = torch.load(path) - else: - checkpoint = torch.load(path, map_location='cpu') + + checkpoint = torch.load(path, map_location=device) + try: model.load_state_dict(checkpoint['model_state_dict']) except: @@ -32,8 +32,13 @@ def load_model(model, optimizer, path, cuda): k = k[len(prefix):] model_state_dict_fixed[k] = v model.load_state_dict(model_state_dict_fixed) + + # Ensure the model is moved to the correct device + model.to(device) + if optimizer is not None: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + if 'state' in checkpoint: state = checkpoint['state'] else: diff --git a/model/waveunet.py b/model/waveunet.py index a14aa55..2391f5d 100644 --- a/model/waveunet.py +++ b/model/waveunet.py @@ -100,9 +100,10 @@ def get_input_size(self, output_size): return curr_size class Waveunet(nn.Module): - def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2): + def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2, device=None): super(Waveunet, self).__init__() + self.device = device if device is not None else torch.device('cpu') self.num_levels = len(num_channels) self.strides = strides self.kernel_size = kernel_size @@ -111,7 +112,7 @@ def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_si self.depth = depth self.instruments = instruments self.separate = separate - + self.to(self.device) # Only odd filter kernels allowed assert(kernel_size % 2 == 1) @@ -195,9 +196,10 @@ def forward_module(self, x, module): :param module: Network module to be used for prediction :return: Source estimates ''' + x = x.to(self.device) shortcuts = [] out = x - + #print(x.shape) # DOWNSAMPLING BLOCKS for block in module.downsampling_blocks: out, short = block(out) diff --git a/train.py b/train.py index 391d379..3164f6f 100644 --- a/train.py +++ b/train.py @@ -22,19 +22,21 @@ def main(args): #torch.backends.cudnn.benchmark=True # This makes dilated conv much faster for CuDNN 7.5 - + device = "cuda" if args.cuda else "cpu" + if args.mps: + device = torch.device('mps') # MODEL num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \ [args.features*2**i for i in range(0, args.levels)] target_outputs = int(args.output_size * args.sr) model = Waveunet(args.channels, num_features, args.channels, args.instruments, kernel_size=args.kernel_size, - target_output_size=target_outputs, depth=args.depth, strides=args.strides, - conv_type=args.conv_type, res=args.res, separate=args.separate) + target_output_size=target_outputs, depth=args.depth, strides=args.strides, + conv_type=args.conv_type, res=args.res, separate=args.separate, device=device).to(device) - if args.cuda: + if args.cuda or args.mps: model = model_utils.DataParallel(model) - print("move model to gpu") - model.cuda() + print(f"move model to {device}") + model.to(device) print('model: ', model) print('parameter count: ', str(sum(p.numel() for p in model.parameters()))) @@ -75,7 +77,7 @@ def main(args): # LOAD MODEL CHECKPOINT IF DESIRED if args.load_model is not None: print("Continuing training full model from checkpoint " + str(args.load_model)) - state = model_utils.load_model(model, optimizer, args.load_model, args.cuda) + state = model_utils.load_model(model, optimizer, args.load_model, device) print('TRAINING START') while state["worse_epochs"] < args.patience: @@ -85,10 +87,9 @@ def main(args): with tqdm(total=len(train_data) // args.batch_size) as pbar: np.random.seed() for example_num, (x, targets) in enumerate(dataloader): - if args.cuda: - x = x.cuda() - for k in list(targets.keys()): - targets[k] = targets[k].cuda() + x = x.to(device) + for k in list(targets.keys()): + targets[k] = targets[k].to(device) t = time.time() @@ -139,13 +140,12 @@ def main(args): print("Saving model...") model_utils.save_model(model, optimizer, state, checkpoint_path) - #### TESTING #### # Test loss print("TESTING") # Load best model based on validation loss - state = model_utils.load_model(model, None, state["best_checkpoint"], args.cuda) + state = model_utils.load_model(model, None, state["best_checkpoint"], device) test_loss = validate(args, model, criterion, test_data) print("TEST FINISHED: LOSS: " + str(test_loss)) writer.add_scalar("test_loss", test_loss, state["step"]) @@ -176,6 +176,8 @@ def main(args): help="List of instruments to separate (default: \"bass drums other vocals\")") parser.add_argument('--cuda', action='store_true', help='Use CUDA (default: False)') + parser.add_argument('--mps', action='store_true', + help='Use MPS on M1 Mac (default: False)') parser.add_argument('--num_workers', type=int, default=1, help='Number of data loader worker threads (default: 1)') parser.add_argument('--features', type=int, default=32,