1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.net.module.util;
18 
19 import static android.system.OsConstants.IPPROTO_ICMPV6;
20 import static android.system.OsConstants.IPPROTO_TCP;
21 import static android.system.OsConstants.IPPROTO_UDP;
22 
23 import java.net.Inet6Address;
24 import java.net.InetAddress;
25 import java.nio.ByteBuffer;
26 import java.nio.ShortBuffer;
27 
28 /**
29  * @hide
30  */
31 public class IpUtils {
32     /**
33      * Converts a signed short value to an unsigned int value.  Needed
34      * because Java does not have unsigned types.
35      */
intAbs(short v)36     private static int intAbs(short v) {
37         return v & 0xFFFF;
38     }
39 
40     /**
41      * Performs an IP checksum (used in IP header and across UDP
42      * payload) on the specified portion of a ByteBuffer.  The seed
43      * allows the checksum to commence with a specified value.
44      */
checksum(ByteBuffer buf, int seed, int start, int end)45     private static int checksum(ByteBuffer buf, int seed, int start, int end) {
46         int sum = seed;
47         final int bufPosition = buf.position();
48 
49         // set position of original ByteBuffer, so that the ShortBuffer
50         // will be correctly initialized
51         buf.position(start);
52         ShortBuffer shortBuf = buf.asShortBuffer();
53 
54         // re-set ByteBuffer position
55         buf.position(bufPosition);
56 
57         final int numShorts = (end - start) / 2;
58         for (int i = 0; i < numShorts; i++) {
59             sum += intAbs(shortBuf.get(i));
60         }
61         start += numShorts * 2;
62 
63         // see if a singleton byte remains
64         if (end != start) {
65             short b = buf.get(start);
66 
67             // make it unsigned
68             if (b < 0) {
69                 b += 256;
70             }
71 
72             sum += b * 256;
73         }
74 
75         sum = ((sum >> 16) & 0xFFFF) + (sum & 0xFFFF);
76         sum = ((sum + ((sum >> 16) & 0xFFFF)) & 0xFFFF);
77         int negated = ~sum;
78         return intAbs((short) negated);
79     }
80 
pseudoChecksumIPv4( ByteBuffer buf, int headerOffset, int protocol, int transportLen)81     private static int pseudoChecksumIPv4(
82             ByteBuffer buf, int headerOffset, int protocol, int transportLen) {
83         int partial = protocol + transportLen;
84         partial += intAbs(buf.getShort(headerOffset + 12));
85         partial += intAbs(buf.getShort(headerOffset + 14));
86         partial += intAbs(buf.getShort(headerOffset + 16));
87         partial += intAbs(buf.getShort(headerOffset + 18));
88         return partial;
89     }
90 
pseudoChecksumIPv6( ByteBuffer buf, int headerOffset, int protocol, int transportLen)91     private static int pseudoChecksumIPv6(
92             ByteBuffer buf, int headerOffset, int protocol, int transportLen) {
93         int partial = protocol + transportLen;
94         for (int offset = 8; offset < 40; offset += 2) {
95             partial += intAbs(buf.getShort(headerOffset + offset));
96         }
97         return partial;
98     }
99 
ipversion(ByteBuffer buf, int headerOffset)100     private static byte ipversion(ByteBuffer buf, int headerOffset) {
101         return (byte) ((buf.get(headerOffset) & (byte) 0xf0) >> 4);
102    }
103 
ipChecksum(ByteBuffer buf, int headerOffset)104     public static short ipChecksum(ByteBuffer buf, int headerOffset) {
105         byte ihl = (byte) (buf.get(headerOffset) & 0x0f);
106         return (short) checksum(buf, 0, headerOffset, headerOffset + ihl * 4);
107     }
108 
transportChecksum(ByteBuffer buf, int protocol, int ipOffset, int transportOffset, int transportLen)109     private static short transportChecksum(ByteBuffer buf, int protocol,
110             int ipOffset, int transportOffset, int transportLen) {
111         if (transportLen < 0) {
112             throw new IllegalArgumentException("Transport length < 0: " + transportLen);
113         }
114         int sum;
115         byte ver = ipversion(buf, ipOffset);
116         if (ver == 4) {
117             sum = pseudoChecksumIPv4(buf, ipOffset, protocol, transportLen);
118         } else if (ver == 6) {
119             sum = pseudoChecksumIPv6(buf, ipOffset, protocol, transportLen);
120         } else {
121             throw new UnsupportedOperationException("Checksum must be IPv4 or IPv6");
122         }
123 
124         sum = checksum(buf, sum, transportOffset, transportOffset + transportLen);
125         if (protocol == IPPROTO_UDP && sum == 0) {
126             sum = (short) 0xffff;
127         }
128         return (short) sum;
129     }
130 
131     /**
132      * Calculate the UDP checksum for an UDP packet.
133      */
udpChecksum(ByteBuffer buf, int ipOffset, int transportOffset)134     public static short udpChecksum(ByteBuffer buf, int ipOffset, int transportOffset) {
135         int transportLen = intAbs(buf.getShort(transportOffset + 4));
136         return transportChecksum(buf, IPPROTO_UDP, ipOffset, transportOffset, transportLen);
137     }
138 
139     /**
140      * Calculate the TCP checksum for a TCP packet.
141      */
tcpChecksum(ByteBuffer buf, int ipOffset, int transportOffset, int transportLen)142     public static short tcpChecksum(ByteBuffer buf, int ipOffset, int transportOffset,
143             int transportLen) {
144         return transportChecksum(buf, IPPROTO_TCP, ipOffset, transportOffset, transportLen);
145     }
146 
147     /**
148      * Calculate the ICMPv6 checksum for an ICMPv6 packet.
149      */
icmpv6Checksum(ByteBuffer buf, int ipOffset, int transportOffset, int transportLen)150     public static short icmpv6Checksum(ByteBuffer buf, int ipOffset, int transportOffset,
151             int transportLen) {
152         return transportChecksum(buf, IPPROTO_ICMPV6, ipOffset, transportOffset, transportLen);
153     }
154 
addressAndPortToString(InetAddress address, int port)155     public static String addressAndPortToString(InetAddress address, int port) {
156         return String.format(
157                 (address instanceof Inet6Address) ? "[%s]:%d" : "%s:%d",
158                 address.getHostAddress(), port);
159     }
160 
isValidUdpOrTcpPort(int port)161     public static boolean isValidUdpOrTcpPort(int port) {
162         return port > 0 && port < 65536;
163     }
164 }
165