You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

314 lines
12 KiB

# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import os
import mock
import unittest2
from oauth2client import _helpers
from oauth2client import client
from oauth2client import crypt
from oauth2client import service_account
def data_filename(filename):
    """Return the path to ``filename`` in the ``data`` dir beside this module."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'data', filename)
def datafile(filename):
    """Read the named test data file and return its contents as bytes."""
    path = data_filename(filename)
    with open(path, 'rb') as handle:
        contents = handle.read()
    return contents
class Test__bad_pkcs12_key_as_pem(unittest2.TestCase):
    """``crypt._bad_pkcs12_key_as_pem`` must always refuse to run."""

    def test_fails(self):
        self.assertRaises(NotImplementedError, crypt._bad_pkcs12_key_as_pem)
class Test_pkcs12_key_as_pem(unittest2.TestCase):
    """Tests for ``crypt.pkcs12_key_as_pem``."""

    def _make_svc_account_creds(self, private_key_file='privatekey.p12'):
        """Build P12-backed service account credentials from test data.

        Args:
            private_key_file: Name of a file in the test ``data`` directory.

        Returns:
            A ``ServiceAccountCredentials`` built via ``from_p12_keyfile``
            with ``'sub'`` set in its ``_kwargs``.
        """
        filename = data_filename(private_key_file)
        credentials = (
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                'some_account@example.com', filename,
                scopes='read+write'))
        credentials._kwargs['sub'] = 'joe@example.org'
        return credentials

    def _succeeds_helper(self, password=None):
        """Convert the test P12 key to PEM and compare with known output.

        Args:
            password: Password for the P12 key. When ``None``, the password
                stored on the credentials (``_private_key_password``) is used.
        """
        # This test requires OpenSSL support to be available.
        self.assertTrue(client.HAS_OPENSSL)

        credentials = self._make_svc_account_creds()
        if password is None:
            password = credentials._private_key_password
        pem_contents = crypt.pkcs12_key_as_pem(
            credentials._private_key_pkcs12, password)

        # Named ``expected_pem`` so it does not shadow the function under
        # test (``pkcs12_key_as_pem``).
        expected_pem = _helpers._parse_pem_key(
            datafile('pem_from_pkcs12.pem'))
        # NOTE(review): only the primary fixture goes through
        # _parse_pem_key; confirm the alternate file needs no normalization.
        alternate_pem = datafile('pem_from_pkcs12_alternate.pem')
        # assertIn shows the actual PEM on failure, unlike assertTrue(x in y).
        self.assertIn(pem_contents, [expected_pem, alternate_pem])

    def test_succeeds(self):
        self._succeeds_helper()

    def test_succeeds_with_unicode_password(self):
        password = u'notasecret'
        self._succeeds_helper(password)
class Test__verify_signature(unittest2.TestCase):
    """Tests for ``crypt._verify_signature`` with ``crypt.Verifier`` mocked."""

    @staticmethod
    def _fake_verifier(**verify_kwargs):
        """Return a mock verifier whose ``verify`` mock is built from kwargs.

        ``verify_kwargs`` is forwarded to ``mock.MagicMock`` for ``verify``
        (e.g. ``return_value`` or ``side_effect``).
        """
        fake = mock.MagicMock()
        fake.verify = mock.MagicMock(name='verify', **verify_kwargs)
        return fake

    def test_success_single_cert(self):
        cert = 'cert-value'
        msg = object()
        sig = object()
        fake_verifier = self._fake_verifier(return_value=True)

        with mock.patch('oauth2client.crypt.Verifier') as verifier_cls:
            verifier_cls.from_string = mock.MagicMock(
                name='from_string', return_value=fake_verifier)
            outcome = crypt._verify_signature(msg, sig, [cert])
            self.assertEqual(outcome, None)

            # The one cert is loaded as X.509 and checked exactly once.
            verifier_cls.from_string.assert_called_once_with(
                cert, is_x509_cert=True)
            fake_verifier.verify.assert_called_once_with(msg, sig)

    def test_success_multiple_certs(self):
        cert_list = ['cert-value1', 'cert-value2', 'cert-value3']
        msg = object()
        sig = object()
        # Only the third verify() succeeds, forcing every cert to be tried.
        fake_verifier = self._fake_verifier(side_effect=[False, False, True])
        expected_loads = [mock.call(cert, is_x509_cert=True)
                          for cert in cert_list]
        expected_checks = [mock.call(msg, sig) for _ in cert_list]

        with mock.patch('oauth2client.crypt.Verifier') as verifier_cls:
            verifier_cls.from_string = mock.MagicMock(
                name='from_string', return_value=fake_verifier)
            outcome = crypt._verify_signature(msg, sig, cert_list)
            self.assertEqual(outcome, None)

            # Each cert is loaded in order, each followed by one verify().
            self.assertEqual(verifier_cls.from_string.mock_calls,
                             expected_loads)
            self.assertEqual(fake_verifier.verify.mock_calls,
                             expected_checks)

    def test_failure(self):
        cert = 'cert-value'
        msg = object()
        sig = object()
        fake_verifier = self._fake_verifier(return_value=False)

        with mock.patch('oauth2client.crypt.Verifier') as verifier_cls:
            verifier_cls.from_string = mock.MagicMock(
                name='from_string', return_value=fake_verifier)
            with self.assertRaises(crypt.AppIdentityError):
                crypt._verify_signature(msg, sig, [cert])

            # The failing cert was still loaded and checked exactly once.
            verifier_cls.from_string.assert_called_once_with(
                cert, is_x509_cert=True)
            fake_verifier.verify.assert_called_once_with(msg, sig)
class Test__check_audience(unittest2.TestCase):
    """Tests for ``crypt._check_audience``."""

    def test_null_audience(self):
        # None for both arguments returns None without raising.
        self.assertEqual(crypt._check_audience(None, None), None)

    def test_success(self):
        expected = 'audience'
        outcome = crypt._check_audience({'aud': expected}, expected)
        # A matching 'aud' claim neither raises nor returns a value.
        self.assertEqual(outcome, None)

    def test_missing_aud(self):
        expected = 'audience'
        with self.assertRaises(crypt.AppIdentityError):
            crypt._check_audience({}, expected)

    def test_wrong_aud(self):
        claimed, expected = 'audience1', 'audience2'
        self.assertNotEqual(claimed, expected)
        with self.assertRaises(crypt.AppIdentityError):
            crypt._check_audience({'aud': claimed}, expected)
class Test__verify_time_range(unittest2.TestCase):
    """Tests for ``crypt._verify_time_range``."""

    def _exception_helper(self, payload_dict, current_time=None):
        """Run ``_verify_time_range`` and return any ``AppIdentityError``.

        Args:
            payload_dict: Token payload passed to ``_verify_time_range``.
            current_time: If given, ``time.time()`` as seen by
                ``oauth2client.crypt`` is patched to return this value for
                the duration of the call; otherwise the real clock is used.

        Returns:
            The ``AppIdentityError`` raised, or ``None`` if verification
            succeeded. Any other exception propagates.
        """
        try:
            if current_time is None:
                crypt._verify_time_range(payload_dict)
            else:
                with mock.patch('oauth2client.crypt.time') as fake_time:
                    fake_time.time = mock.MagicMock(
                        name='time', return_value=current_time)
                    crypt._verify_time_range(payload_dict)
        except crypt.AppIdentityError as exc:
            return exc
        return None

    def _assert_error_prefix(self, payload_dict, prefix, current_time=None):
        """Assert verification fails with a message starting with ``prefix``.

        Args:
            payload_dict: Token payload passed to ``_verify_time_range``.
            prefix: Expected start of the ``AppIdentityError`` message.
            current_time: Optional pinned clock value; see
                ``_exception_helper``.
        """
        exception_caught = self._exception_helper(
            payload_dict, current_time=current_time)
        self.assertIsNotNone(exception_caught)
        message = str(exception_caught)
        # Report the real message on failure; a bare
        # assertTrue(...startswith(...)) only says "False is not true".
        self.assertTrue(
            message.startswith(prefix),
            'Expected message starting with {0!r}, got {1!r}'.format(
                prefix, message))

    def test_without_issued_at(self):
        self._assert_error_prefix({}, 'No iat field in token')

    def test_without_expiration(self):
        self._assert_error_prefix({'iat': 'iat'}, 'No exp field in token')

    def test_with_bad_token_lifetime(self):
        current_time = 123456
        payload_dict = {
            'iat': 'iat',
            # One second past the maximum allowed lifetime.
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS + 1,
        }
        self._assert_error_prefix(
            payload_dict, 'exp field too far in future',
            current_time=current_time)

    def test_with_issued_at_in_future(self):
        current_time = 123456
        payload_dict = {
            # One second beyond the allowed clock skew into the future.
            'iat': current_time + crypt.CLOCK_SKEW_SECS + 1,
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        self._assert_error_prefix(
            payload_dict, 'Token used too early',
            current_time=current_time)

    def test_with_expiration_in_the_past(self):
        current_time = 123456
        payload_dict = {
            'iat': current_time,
            # Expired one second beyond the allowed clock skew.
            'exp': current_time - crypt.CLOCK_SKEW_SECS - 1,
        }
        self._assert_error_prefix(
            payload_dict, 'Token used too late',
            current_time=current_time)

    def test_success(self):
        current_time = 123456
        payload_dict = {
            'iat': current_time,
            'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1,
        }
        exception_caught = self._exception_helper(
            payload_dict, current_time=current_time)
        self.assertIsNone(exception_caught)
class Test_verify_signed_jwt_with_certs(unittest2.TestCase):
    """Tests for ``crypt.verify_signed_jwt_with_certs``."""

    def _assert_app_identity_error(self, jwt, prefix):
        """Assert verifying ``jwt`` (with no certs) fails with ``prefix``.

        Args:
            jwt: Encoded token bytes passed to the function under test.
            prefix: Expected start of the ``AppIdentityError`` message.
        """
        # assertRaises replaces the manual try/except capture and fails
        # with a clear message when nothing is raised.
        with self.assertRaises(crypt.AppIdentityError) as exc_info:
            crypt.verify_signed_jwt_with_certs(jwt, None)
        message = str(exc_info.exception)
        self.assertTrue(
            message.startswith(prefix),
            'Expected message starting with {0!r}, got {1!r}'.format(
                prefix, message))

    def test_jwt_no_segments(self):
        self._assert_app_identity_error(
            b'', 'Wrong number of segments in token')

    def test_jwt_payload_bad_json(self):
        header = signature = b''
        payload = base64.b64encode(b'{BADJSON')
        jwt = b'.'.join([header, payload, signature])
        self._assert_app_identity_error(jwt, 'Can\'t parse token')

    # mock.patch decorators apply bottom-up, so the patch closest to the
    # function (_verify_signature) supplies the first mock argument.
    @mock.patch('oauth2client.crypt._check_audience')
    @mock.patch('oauth2client.crypt._verify_time_range')
    @mock.patch('oauth2client.crypt._verify_signature')
    def test_success(self, verify_sig, verify_time, check_aud):
        certs = mock.MagicMock()
        cert_values = object()
        certs.values = mock.MagicMock(name='values',
                                      return_value=cert_values)
        audience = object()

        header = b'header'
        signature_bytes = b'signature'
        signature = base64.b64encode(signature_bytes)
        payload_dict = {'a': 'b'}
        payload = base64.b64encode(b'{"a": "b"}')
        jwt = b'.'.join([header, payload, signature])

        result = crypt.verify_signed_jwt_with_certs(
            jwt, certs, audience=audience)
        self.assertEqual(result, payload_dict)

        # The signed message is the first two segments joined by '.', and
        # the signature is passed decoded; certs are passed as .values().
        message_to_sign = header + b'.' + payload
        verify_sig.assert_called_once_with(
            message_to_sign, signature_bytes, cert_values)
        verify_time.assert_called_once_with(payload_dict)
        check_aud.assert_called_once_with(payload_dict, audience)
        certs.values.assert_called_once_with()