Coverage for tests\test_get_module_device.py: 70% (27 statements)

from __future__ import annotations

import pytest
import torch

from zanj.torchutil import get_module_device


def test_get_module_device_single_device():
    # Create a model and move it to a device
    model = torch.nn.Linear(10, 2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Run the function
    is_single, device_or_dict = get_module_device(model)

    # Assert that all parameters are on the same device and that device is returned
    assert is_single
    assert device_or_dict == device


def test_get_module_device_multiple_devices():
    # Create a model with parameters on different devices
    if torch.cuda.device_count() < 1:
        pytest.skip("This test requires at least one CUDA device")

    model = torch.nn.Linear(10, 2)
    # Tensor.to() is not in-place, so reassign .data to actually move each parameter
    model.weight.data = model.weight.data.to("cuda:0")
    model.bias.data = model.bias.data.to("cpu")

    # Run the function
    is_single, device_or_dict = get_module_device(model)

    # Assert that not all parameters are on the same device and a dict is returned
    assert not is_single
    assert isinstance(device_or_dict, dict)

    # Check that the dict maps the correct devices
    assert device_or_dict["weight"] == torch.device("cuda:0")
    assert device_or_dict["bias"] == torch.device("cpu")


def test_get_module_device_no_parameters():
    # Create a model with no parameters
    model = torch.nn.Sequential()

    # Run the function
    is_single, device_or_dict = get_module_device(model)

    # Assert that an empty dict is returned
    assert not is_single
    assert device_or_dict == {}
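

# The following test is a hedged sketch, not part of the original file: it
# assumes get_module_device also recurses into the submodules of a container
# and keeps the (is_single, device_or_dict) return convention used above.
def test_get_module_device_nested_module_sketch():
    # Build a small multi-layer model; its parameters default to the CPU
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 4),
        torch.nn.ReLU(),
        torch.nn.Linear(4, 2),
    )

    # Run the function
    is_single, device_or_dict = get_module_device(model)

    # All parameters live on the same (CPU) device, so a single device is expected
    assert is_single
    assert device_or_dict == torch.device("cpu")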