You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
294 lines
13 KiB
294 lines
13 KiB
# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""Tests for oauth2client.contrib.xsrfutil."""
|
|
|
|
import base64
|
|
|
|
import mock
|
|
import unittest2
|
|
|
|
from oauth2client import _helpers
|
|
from oauth2client.contrib import xsrfutil
|
|
|
|
# Shared HMAC key passed to generate_token/validate_token in these tests.
TEST_KEY = b'test key'
# Jan. 17, 2008 22:40:42.081230 UTC (5:40PM EST), written with microsecond
# digits. The tests in this module add offsets in *seconds* to it (e.g.
# ``TEST_TIME + 15 * 60`` for "15 minutes later"), so only relative
# differences matter, not the absolute date this number would decode to.
TEST_TIME = 1200609642081230
# Two distinct user IDs, so a token issued to one can be checked against
# the other.
TEST_USER_ID_1 = 123832983
TEST_USER_ID_2 = 938297432
# Two distinct action IDs, used the same way as the user IDs.
TEST_ACTION_ID_1 = b'some_action'
TEST_ACTION_ID_2 = b'some_other_action'
TEST_EXTRA_INFO_1 = b'extra_info_1'
TEST_EXTRA_INFO_2 = b'more_extra_info'


__author__ = 'jcgregorio@google.com (Joe Gregorio)'
|
|
|
|
|
|
class Test_generate_token(unittest2.TestCase):
    """Unit tests for xsrfutil.generate_token, with hmac/time mocked."""

    @staticmethod
    def _fake_digester(digest):
        """Build a stand-in HMAC object whose digest() returns *digest*.

        digest() is given an explicit mock name so the mock library does not
        attach it as a child of the stand-in; that keeps the digest() call
        out of ``method_calls``, leaving only the update() calls to compare.
        """
        fake = mock.MagicMock()
        fake.digest = mock.MagicMock(name='digest', return_value=digest)
        return fake

    @staticmethod
    def _expected_updates(when_bytes):
        """Return the update() calls expected for TEST_USER_ID_1 and
        TEST_ACTION_ID_1, ending with the timestamp bytes *when_bytes*."""
        return [
            mock.call.update(_helpers._to_bytes(str(TEST_USER_ID_1))),
            mock.call.update(xsrfutil.DELIMITER),
            mock.call.update(TEST_ACTION_ID_1),
            mock.call.update(xsrfutil.DELIMITER),
            mock.call.update(when_bytes),
        ]

    def test_bad_positional(self):
        # Two positional arguments are required ...
        self.assertRaises(TypeError, xsrfutil.generate_token, None)
        # ... and a third positional one is rejected.
        self.assertRaises(TypeError, xsrfutil.generate_token,
                          None, None, None)

    def test_it(self):
        digest = b'foobar'
        digester = self._fake_digester(digest)
        with mock.patch('oauth2client.contrib.xsrfutil.hmac') as fake_hmac:
            fake_hmac.new = mock.MagicMock(name='new', return_value=digester)
            token = xsrfutil.generate_token(TEST_KEY,
                                            TEST_USER_ID_1,
                                            action_id=TEST_ACTION_ID_1,
                                            when=TEST_TIME)
        fake_hmac.new.assert_called_once_with(TEST_KEY)
        digester.digest.assert_called_once_with()

        when_bytes = _helpers._to_bytes(str(TEST_TIME))
        self.assertEqual(digester.method_calls,
                         self._expected_updates(when_bytes))
        # The token is the urlsafe base64 of digest + DELIMITER + timestamp.
        self.assertEqual(token, base64.urlsafe_b64encode(
            digest + xsrfutil.DELIMITER + when_bytes))

    def test_with_system_time(self):
        digest = b'foobar'
        curr_time = 1440449755.74
        digester = self._fake_digester(digest)
        with mock.patch('oauth2client.contrib.xsrfutil.hmac') as fake_hmac:
            fake_hmac.new = mock.MagicMock(name='new', return_value=digester)
            with mock.patch('oauth2client.contrib.xsrfutil.time') as fake_time:
                fake_time.time = mock.MagicMock(name='time',
                                                return_value=curr_time)
                # when= is omitted, so the (mocked) clock is read instead.
                token = xsrfutil.generate_token(TEST_KEY,
                                                TEST_USER_ID_1,
                                                action_id=TEST_ACTION_ID_1)
        fake_hmac.new.assert_called_once_with(TEST_KEY)
        fake_time.time.assert_called_once_with()
        digester.digest.assert_called_once_with()

        # The float clock reading is expected truncated to whole seconds.
        when_bytes = _helpers._to_bytes(str(int(curr_time)))
        self.assertEqual(digester.method_calls,
                         self._expected_updates(when_bytes))
        self.assertEqual(token, base64.urlsafe_b64encode(
            digest + xsrfutil.DELIMITER + when_bytes))
|
|
|
|
|
|
class Test_validate_token(unittest2.TestCase):
    """Unit tests for xsrfutil.validate_token.

    Where the comparison against a regenerated token matters,
    xsrfutil.generate_token is patched out so only validate_token's own
    logic is exercised.
    """

    @staticmethod
    def _bare_time_token(token_time):
        """Return base64 of the decimal string for *token_time* alone."""
        return base64.b64encode(_helpers._to_bytes(str(token_time)))

    def _check_with_stubbed_generator(self, check, token_time, token,
                                      generated_token):
        """Validate *token* while generate_token returns *generated_token*.

        *check* (e.g. ``self.assertTrue``) is applied to validate_token's
        result inside the patch. Afterwards, asserts the stub was called
        exactly once with the key/user/action objects and
        ``when=token_time``.
        """
        # One second short of the timeout, so the token isn't too old.
        now = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
        key, user_id, action_id = object(), object(), object()
        with mock.patch('oauth2client.contrib.xsrfutil.generate_token',
                        return_value=generated_token) as fake_generate:
            check(xsrfutil.validate_token(key, token, user_id,
                                          current_time=now,
                                          action_id=action_id))
        fake_generate.assert_called_once_with(key, user_id,
                                              action_id=action_id,
                                              when=token_time)

    def test_bad_positional(self):
        # Fewer than three positional arguments is an error ...
        self.assertRaises(TypeError, xsrfutil.validate_token, None, None)
        # ... and so is more than three.
        self.assertRaises(TypeError, xsrfutil.validate_token,
                          None, None, None, None)

    def test_no_token(self):
        self.assertFalse(xsrfutil.validate_token(None, None, None))

    def test_token_not_valid_base64(self):
        bad_padding = b'a'
        self.assertFalse(xsrfutil.validate_token(None, bad_padding, None))

    def test_token_non_integer(self):
        # Neither delimited field of the decoded token is an integer.
        token = base64.b64encode(xsrfutil.DELIMITER.join((b'abc', b'xyz')))
        self.assertFalse(xsrfutil.validate_token(None, token, None))

    def test_token_too_old_implicit_current_time(self):
        token_time = 123456789
        # One second past the timeout.
        now = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
        token = self._bare_time_token(token_time)
        with mock.patch('oauth2client.contrib.xsrfutil.time') as fake_time:
            fake_time.time = mock.MagicMock(name='time', return_value=now)
            self.assertFalse(xsrfutil.validate_token(None, token, None))
            # No current_time was given, so the clock must be consulted.
            fake_time.time.assert_called_once_with()

    def test_token_too_old_explicit_current_time(self):
        token_time = 123456789
        now = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
        token = self._bare_time_token(token_time)
        self.assertFalse(xsrfutil.validate_token(None, token, None,
                                                 current_time=now))

    def test_token_length_differs_from_generated(self):
        token_time = 123456789
        token = self._bare_time_token(token_time)
        generated_token = b'a'
        # Guarantee the token length comparison is the one that fails.
        self.assertNotEqual(len(token), len(generated_token))
        self._check_with_stubbed_generator(self.assertFalse, token_time,
                                           token, generated_token)

    def test_token_differs_from_generated_but_same_length(self):
        token_time = 123456789
        token = self._bare_time_token(token_time)
        # token is b'MTIzNDU2Nzg5' (length 12): match its length only.
        generated_token = b'M' * 12
        # Length comparison succeeds; content comparison must fail.
        self.assertEqual(len(token), len(generated_token))
        self.assertNotEqual(token, generated_token)
        self._check_with_stubbed_generator(self.assertFalse, token_time,
                                           token, generated_token)

    def test_success(self):
        token_time = 123456789
        token = self._bare_time_token(token_time)
        # The stub regenerates exactly the token under test.
        self._check_with_stubbed_generator(self.assertTrue, token_time,
                                           token, token)
|
|
|
|
|
|
class XsrfUtilTests(unittest2.TestCase):
    """Round-trip tests: tokens from generate_token fed to validate_token."""

    def testGenerateAndValidateToken(self):
        """Test generating and validating a token."""
        token = xsrfutil.generate_token(TEST_KEY,
                                        TEST_USER_ID_1,
                                        action_id=TEST_ACTION_ID_1,
                                        when=TEST_TIME)
        later15mins = TEST_TIME + 15 * 60
        later2hours = TEST_TIME + 2 * 60 * 60

        def is_valid(key=TEST_KEY, candidate=token, user_id=TEST_USER_ID_1,
                     action_id=TEST_ACTION_ID_1, current_time=later15mins):
            """Call validate_token with round-trip defaults, overridable."""
            return xsrfutil.validate_token(key,
                                           candidate,
                                           user_id,
                                           action_id=action_id,
                                           current_time=current_time)

        # Valid at the instant it was issued ...
        self.assertTrue(is_valid(current_time=TEST_TIME))
        # ... still valid 15 minutes later ...
        self.assertTrue(is_valid())
        # ... but not once beyond the timeout.
        self.assertFalse(is_valid(current_time=later2hours))

        # A different key, user ID, or action ID must each fail.
        self.assertFalse(is_valid(key='another key'))
        self.assertFalse(is_valid(user_id=TEST_USER_ID_2))
        self.assertFalse(is_valid(action_id=TEST_ACTION_ID_2))

        # Tampered tokens fail: truncated, or with trailing garbage.
        self.assertFalse(is_valid(candidate=token[:-1]))
        self.assertFalse(is_valid(candidate=token + b'x'))

        # A token of None is rejected (no current_time is passed here).
        self.assertFalse(xsrfutil.validate_token(TEST_KEY,
                                                 None,
                                                 TEST_USER_ID_1,
                                                 action_id=TEST_ACTION_ID_1))
|