Skip to content

sgnts.transforms.converter

Converter dataclass

Bases: TSTransform

Change the data type or the device of the data.

Parameters:

Name Type Description Default
backend str

str, the backend to convert the data to. Supported backends: ['numpy'|'torch']

'numpy'
dtype str

str, the data type to convert the data to. Supported dtypes: ['float64'|'float32'|'float16'] (with the 'torch' backend a torch.dtype is also accepted).

'float32'
device str

str, the device to convert the data to. Supported devices: if backend = 'numpy', only device = 'cpu' is supported; if backend = 'torch', supported devices are ['cpu'|'cuda'|'cuda:N'], where N is the GPU device number.

'cpu'
Source code in sgnts/transforms/converter.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
@dataclass
class Converter(TSTransform):
    """Change the data type and/or the device of the data.

    Args:
        backend:
            str, the backend to convert the data to. Supported backends:
            ['numpy'|'torch']
        dtype:
            the data type to convert the data to. Documented dtypes:
            ['float64'|'float32'|'float16']. With backend='torch' this may be
            one of those strings or a ``torch.dtype``. With backend='numpy' it
            must be something ``np.dtype`` understands; it is passed to
            ``ndarray.astype``.
        device:
            str, the device to convert the data to. Supported devices:
            if backend = 'numpy', only device = 'cpu' is supported; if
            backend = 'torch', supported devices are ['cpu'|'cuda'|'cuda:<N>']
            where <N> is the GPU device number.
    """

    backend: str = "numpy"
    dtype: str = "float32"
    device: str = "cpu"

    def __post_init__(self):
        # Every source pad must have a same-named sink pad; pad_map below
        # relies on this pairing.
        assert set(self.source_pad_names) == set(self.sink_pad_names)
        super().__post_init__()

        if self.backend == "numpy":
            if self.device != "cpu":
                raise ValueError("Converting to numpy only supports device as cpu")
            # Validate the dtype now rather than failing on the first buffer
            # in new(). self.dtype itself is left untouched for astype().
            try:
                np.dtype(self.dtype)
            except TypeError as err:
                raise ValueError(
                    f"Unknown dtype for numpy backend: {self.dtype!r}"
                ) from err
        elif self.backend == "torch":
            torch_dtypes = {
                "float64": torch.float64,
                "float32": torch.float32,
                "float16": torch.float16,
            }
            if isinstance(self.dtype, str):
                if self.dtype not in torch_dtypes:
                    raise ValueError(
                        "Supported torch data types: float64, float32, float16"
                    )
                # Replace the string name with the concrete torch.dtype.
                self.dtype = torch_dtypes[self.dtype]
            elif not isinstance(self.dtype, torch.dtype):
                raise ValueError("Unknown dtype")
        else:
            raise ValueError("Supported backends: 'numpy' or 'torch'")

        # Pair each source pad with the sink pad "<name>:sink:<suffix>",
        # where <suffix> is the text after the source pad's last ':'.
        self.pad_map = {
            p: self.sink_pad_dict[f"{self.name}:sink:{p.name.split(':')[-1]}"]
            for p in self.source_pads
        }

    def _convert_data(self, data):
        # Convert one non-gap buffer payload to the configured backend,
        # dtype and device. Raises ValueError for payloads that are neither
        # np.ndarray nor torch.Tensor.
        if self.backend == "numpy":
            if isinstance(data, np.ndarray):
                # numpy -> numpy; copy=False avoids a copy when dtype matches.
                return data.astype(self.dtype, copy=False)
            if isinstance(data, torch.Tensor):
                # torch -> numpy; detach() so tensors requiring grad can be
                # exported, cpu() because numpy() needs host memory.
                return data.detach().cpu().numpy().astype(self.dtype, copy=False)
        else:
            if isinstance(data, np.ndarray):
                # numpy -> torch: wrap the array, then fall through to the
                # torch -> torch conversion below.
                data = torch.from_numpy(data)
            if isinstance(data, torch.Tensor):
                # One .to() call does the dtype cast and device move together.
                return data.to(device=self.device, dtype=self.dtype)
        raise ValueError("Unsupported data type")

    @wraps(TSTransform.new)
    def new(self, pad):
        # Convert every buffer of the frame prepared on the sink pad mapped
        # to `pad`. Gap buffers stay gaps (data=None) and keep their shape;
        # offsets, sample rates, metadata and EOS are passed through.
        frame = self.preparedframes[self.pad_map[pad]]
        # Clear the consumed slot.
        self.preparedframes[self.pad_map[pad]] = None

        outbufs = []
        for buf in frame:
            out = None if buf.is_gap else self._convert_data(buf.data)
            outbufs.append(
                SeriesBuffer(
                    offset=buf.offset,
                    sample_rate=buf.sample_rate,
                    data=out,
                    shape=buf.shape,
                )
            )

        return TSFrame(
            buffers=outbufs,
            metadata=frame.metadata,
            EOS=frame.EOS,
        )