summaryrefslogtreecommitdiff
path: root/src/couch/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/couch/utils.py')
-rw-r--r--src/couch/utils.py58
1 files changed, 24 insertions, 34 deletions
diff --git a/src/couch/utils.py b/src/couch/utils.py
index 1cd21d8..f0883a6 100644
--- a/src/couch/utils.py
+++ b/src/couch/utils.py
@@ -1,26 +1,15 @@
# -*- coding: utf-8 -*-
# Based on py-couchdb (https://github.com/histrio/py-couchdb)
-
+from typing import Tuple, Union, Dict, List, Any
import json
import sys
+import requests
+from urllib.parse import unquote as _unquote
+from urllib.parse import urlunsplit, urlsplit
-if sys.version_info[0] == 3:
- from urllib.parse import unquote as _unquote
- from urllib.parse import urlunsplit, urlsplit
-
- string_type = str
- bytes_type = bytes
-
- from functools import reduce
-
-else:
- from urllib import unquote as _unquote
- from urlparse import urlunsplit, urlsplit
-
- string_type = unicode # noqa: F821
- bytes_type = str
+from functools import reduce
URLSPLITTER = '/'
@@ -28,7 +17,7 @@ URLSPLITTER = '/'
json_encoder = json.JSONEncoder()
-def extract_credentials(url):
+def extract_credentials(url: str) -> Tuple[str, Union[Tuple[str, str], None]]:
"""
Extract authentication (user name and password) credentials from the
given URL.
@@ -46,19 +35,20 @@ def extract_credentials(url):
if '@' in netloc:
creds, netloc = netloc.split('@')
credentials = tuple(_unquote(i) for i in creds.split(':'))
- parts = list(parts)
- parts[1] = netloc
- else:
- credentials = None
- return urlunsplit(parts), credentials
+ parts_list = list(parts)
+ parts_list[1] = netloc
+ return urlunsplit(parts_list), (credentials[0], credentials[1])
+
+ parts_list = list(parts)
+ return urlunsplit(parts_list), None
-def _join(head, tail):
+def _join(head: str, tail: str) -> str:
parts = [head.rstrip(URLSPLITTER), tail.lstrip(URLSPLITTER)]
return URLSPLITTER.join(parts)
-def urljoin(base, *path):
+def urljoin(base: str, *path: str) -> str:
"""
Assemble a uri based on a base, any number of path segments, and query
string parameters.
@@ -87,18 +77,18 @@ def urljoin(base, *path):
"""
return reduce(_join, path, base)
-
-def as_json(response):
+# Probably bugs here
+def as_json(response: requests.models.Response) -> Union[Dict[str, Any], None]:
if "application/json" in response.headers['content-type']:
response_src = response.content.decode('utf-8')
if response.content != b'':
- return json.loads(response_src)
+ ret: Dict[str, Any] = json.loads(response_src)
+ return ret
else:
return response_src
return None
-
-def _path_from_name(name, type):
+def _path_from_name(name: str, type: str) -> List[str]:
"""
Expand a 'design/foo' style name to its full path as a list of
segments.
@@ -114,7 +104,7 @@ def _path_from_name(name, type):
return ['_design', design, type, name]
-def encode_view_options(options):
+def encode_view_options(options: Dict[str, Any]) -> Dict[str, Any]:
"""
Encode any items in the options dict that are sent as a JSON string to a
view/list function.
@@ -138,13 +128,13 @@ def encode_view_options(options):
return retval
-def force_bytes(data, encoding="utf-8"):
- if isinstance(data, string_type):
+def force_bytes(data: Union[str, bytes], encoding: str = "utf-8") -> bytes:
+ if isinstance(data, str):
data = data.encode(encoding)
return data
-def force_text(data, encoding="utf-8"):
- if isinstance(data, bytes_type):
+def force_text(data: Union[str, bytes], encoding: str = "utf-8") -> str:
+ if isinstance(data, bytes):
data = data.decode(encoding)
return data