| import torch |
| import megablocks |
|
|
def test_import():
    """Check that megablocks imports and exposes its expected public names.

    Prints the package's attribute listing (useful when debugging a failing
    build), then asserts that every name in ``expected_functions`` is an
    attribute of the package. All missing names are reported in a single
    failure message instead of stopping at the first one.
    """
    print("megablocks_moe module imported successfully.")
    available = dir(megablocks)
    print("Available functions:", available)

    # NOTE(review): judging by naming, this list mixes functions with what
    # look like classes and submodules; the check only requires each name
    # to be an attribute of the package.
    expected_functions = [
        "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
        "SparseGLU", "SparseMLP", "argsort",
        "backend", "cumsum", "dMoE", "exclusive_cumsum",
        "get_load_balancing_loss", "grouped_gemm_util", "histogram",
        "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
        "replicate_forward", "sort", "torch",
    ]

    # Build the lookup set once instead of calling dir() for every name.
    available_names = set(available)
    missing = [name for name in expected_functions if name not in available_names]
    assert not missing, f"Missing function(s): {missing}"
|
|
| |
def test_exclusive_cumsum():
    """Check megablocks.exclusive_cumsum writes an exclusive prefix sum into ``out``.

    For input [1, 2, 3, 4] along dim 0, element i of the result is the sum of
    the elements before i, so the expected output is [0, 1, 3, 6].
    """
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16, device="cuda")
    # Output buffer shares x's dtype (int16) and device; the op fills it.
    out = torch.empty_like(x)
    megablocks.exclusive_cumsum(x, 0, out)
    # Print before asserting so the actual output is visible on failure too.
    print("cumsum output:", out)
    # Build the reference in the same dtype/device as ``out``. A float32
    # reference against an int16 result would rely on torch.equal tolerating
    # mixed dtypes and would not catch a dtype regression in the op.
    expected = torch.tensor([0, 1, 3, 6], dtype=x.dtype, device=x.device)
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
|
|
| |
def test_inclusive_cumsum():
    """Check megablocks.inclusive_cumsum writes an inclusive prefix sum into ``out``.

    For input [1, 2, 3, 4] along dim 0, element i of the result is the sum of
    elements 0..i, so the expected output is [1, 3, 6, 10].
    """
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16, device="cuda")
    # Output buffer shares x's dtype (int16) and device; the op fills it.
    out = torch.empty_like(x)
    megablocks.inclusive_cumsum(x, dim=0, out=out)
    # Build the reference in the same dtype/device as ``out``. A float32
    # reference against an int16 result would rely on torch.equal tolerating
    # mixed dtypes and would not catch a dtype regression in the op.
    expected = torch.tensor([1, 3, 6, 10], dtype=x.dtype, device=x.device)
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
|
|
| |
def test_histogram():
    """Check megablocks.histogram counts how often each value 0..num_bins-1 occurs.

    In the input, value k appears k + 1 times, so the expected int32 counts
    are [1, 2, 3].
    """
    bins = 3
    values = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16, device="cuda")
    reference = torch.tensor([1, 2, 3], dtype=torch.int32, device="cuda")

    counts = megablocks.histogram(values, bins)

    assert torch.equal(counts, reference), f"Expected {reference}, got {counts}"
|
|