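@torch_jit_script_unless_coverage
def batch_to(
    systems: List[System],
    targets: Dict[str, TensorMap],
    extra_data: Optional[Dict[str, TensorMap]] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
):
    """
    Moves the systems, targets, and optional extra data to the specified
    floating point data type and device.

    Entries of ``extra_data`` whose key ends with ``"_mask"`` are always
    converted to ``torch.bool``, regardless of ``dtype``.

    :param systems: List of systems.
    :param targets: Dictionary of targets.
    :param extra_data: Optional dictionary of extra data.
    :param dtype: Desired floating point data type.
    :param device: Desired device.
    :return: The converted systems, targets, and extra data (``None`` if no
        extra data was given).
    """

    systems = [system.to(dtype=dtype, device=device) for system in systems]
    targets = {
        key: value.to(dtype=dtype, device=device) for key, value in targets.items()
    }
    if extra_data is not None:
        # One target dtype per entry of extra_data, in iteration order.
        # Annotated as int because TorchScript represents dtypes as integers.
        new_dtypes: List[Optional[int]] = []
        for key in extra_data.keys():
            if key.endswith("_mask"):  # masks should always be boolean
                new_dtypes.append(torch.bool)
            else:
                new_dtypes.append(dtype)
        extra_data = {
            key: value.to(dtype=_dtype, device=device)
            for (key, value), _dtype in zip(extra_data.items(), new_dtypes)
        }

    return systems, targets, extra_data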