diff --git a/pytools/__init__.py b/pytools/__init__.py index bfe6cee0ae4c729e904ecfb0d5007028f9f29f45..e018f949836cef6420c50ecf75d24e6aed5fe864 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -179,6 +179,11 @@ Sampling .. autofunction:: sphere_sample_equidistant .. autofunction:: sphere_sample_fibonacci +String utilities +---------------- + +.. autofunction:: strtobool + Type Variables Used ------------------- @@ -2907,6 +2912,44 @@ def sphere_sample_fibonacci( # }}} +# {{{ strtobool + +def strtobool(val: Optional[str], default: Optional[bool] = None) -> bool: + """Convert a string representation of truth to True or False. + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Uppercase versions are + also accepted. If *default* is None, raises ValueError if *val* is anything + else. If *val* is None and *default* is not None, returns *default*. + Based on :func:`distutils.util.strtobool`. + + :param val: Value to convert. + :param default: Value to return if *val* is None. + + :returns: Truth value of *val*. + """ + if val is None and default is not None: + return default + + if val is None: + raise ValueError(f"invalid truth value '{val}'. " + "Valid values are ('y', 'yes', 't', 'true', 'on', '1') " + "for 'True' and ('n', 'no', 'f', 'false', 'off', '0') " + "for 'False'. Uppercase versions are also accepted.") + + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError(f"invalid truth value '{val}'. " + "Valid values are ('y', 'yes', 't', 'true', 'on', '1') " + "for 'True' and ('n', 'no', 'f', 'false', 'off', '0') " + "for 'False'. Uppercase versions are also accepted.") + +# }}} + + def _test(): import doctest doctest.testmod() diff --git a/test/test_pytools.py b/test/test_pytools.py index 214e014817a86ed3547689ba9248473461bcc38b..d89e9c9ed77d781d56428ebecc812b9bdb3faf35 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -703,6 +703,29 @@ def test_ignoredforequalitytag(): assert hash(eq1) != hash(eq3) +def test_strtobool(): + from pytools import strtobool + assert strtobool("true") is True + assert strtobool("tRuE") is True + assert strtobool("1") is True + assert strtobool("t") is True + assert strtobool("on") is True + + assert strtobool("false") is False + assert strtobool("FaLse") is False + assert strtobool("0") is False + assert strtobool("f") is False + assert strtobool("off") is False + + with pytest.raises(ValueError): + strtobool("tru") + strtobool("fal") + strtobool("xxx") + strtobool(".") + + assert strtobool(None, False) is False + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])