netlib.cpp

Go to the documentation of this file.
00001 /*
00002  * netlib.cpp
00003  *
00004  * Copyright (C) 2007-2009  Thomas A. Vaughan
00005  * All rights reserved.
00006  *
00007  *
00008  * Redistribution and use in source and binary forms, with or without
00009  * modification, are permitted provided that the following conditions are met:
00010  *     * Redistributions of source code must retain the above copyright
00011  *       notice, this list of conditions and the following disclaimer.
00012  *     * Redistributions in binary form must reproduce the above copyright
00013  *       notice, this list of conditions and the following disclaimer in the
00014  *       documentation and/or other materials provided with the distribution.
00015  *     * Neither the name of the <organization> nor the
00016  *       names of its contributors may be used to endorse or promote products
00017  *       derived from this software without specific prior written permission.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THOMAS A. VAUGHAN ''AS IS'' AND ANY
00020  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00021  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00022  * DISCLAIMED. IN NO EVENT SHALL THOMAS A. VAUGHAN BE LIABLE FOR ANY
00023  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00024  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00025  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00026  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00028  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *
00030  *
00031  * Implementation of the networking library.  See netlib.h
00032  */
00033 
00034 // includes --------------------------------------------------------------------
00035 #include "netlib.h"             // always include our own header first!
00036 #include "wavesock.h"
00037 
#include <deque>

#ifndef _XOPEN_SOURCE
#define _XOPEN_SOURCE 600
#endif  // _XOPEN_SOURCE

#include <errno.h>
#include <stdlib.h>
#include <string.h>
00045 
00046 #include "common/wave_ex.h"
00047 #include "perf/perf.h"
00048 #include "util/parsing.h"
00049 
00050 
00051 namespace netlib {
00052 
00053 
// Size of each connection's read buffer, and the largest chunk written by a
// single send.  One byte is kept back for a null terminator, so at most
// s_chunkSize - 1 bytes move per call.  (Shrink this to exercise the chunking
// logic while testing; keep it large for production.)
static const int s_chunkSize            = 8 * 1024;

// Longest header line we accumulate before throwing it away as garbage.
static const int s_maxHeaderLine        = 1 << 6;
00058 
00059 struct request_t {
00060         // constructor, manipulators
00061         request_t(void) throw() { this->clear(); }
00062         void clear(void) throw() {
00063                         connId = 0;
00064                         msgbuf = NULL;
00065                 }
00066         bool is_empty(void) const throw() {
00067                         return (!connId && !msgbuf);
00068                 }
00069 
00070         // data fields
00071         conn_id_t                       connId;
00072         smart_ptr<MessageBuffer>        msgbuf;
00073 };
00074 
00075 
00076 
00077 // TODO: don't require memory alloc/free!  Keep a free list
00078 typedef std::deque<request_t> message_queue_t;
00079 
00080 
00081 
00082 // conn_rec_t : connection record
00083 struct conn_rec_t {
00084         // constructor, manipulators
00085         conn_rec_t(void) throw() : socket(-1) { this->clear(); }
00086         ~conn_rec_t(void) { this->clear(); }
00087         void clear(void) throw() { 
00088                         if (wsIsValidSocket(socket)) {
00089                                 DPRINTF("Closing down socket!");
00090                                 DPRINTF("Connection id = 0x%lx", (long) conn_id);
00091                                 wsCloseSocket(socket);
00092                         }
00093 
00094                         conn_id = 0;
00095                         local = 0;
00096                         socket = -1;
00097                         need_bytes = -1;
00098                         type = eType_Invalid;
00099                         msgbuf = NULL;
00100                         message_queue.clear();
00101                         send_byte = -1;
00102                         buffer[0] = 0;
00103                         buff_idx = -1;  // empty
00104                         buff_len = 0;
00105                         header[0] = 0;
00106                         head_idx = 0;
00107                         udpFrom.clear();
00108                         address.clear();
00109                 }
00110         void dump(IN const char * text) const throw() {
00111                         DPRINTF("%s", text);
00112                         DPRINTF("  connId: 0x%04lx", (long) conn_id);
00113                         DPRINTF("  socket: %d", socket);
00114                         DPRINTF("  type: %d", type);
00115                         address.dump(text);
00116                 }
00117 
00118         // data fields
00119         conn_id_t       conn_id;
00120         conn_id_t       local;          // local peer connection (UDP)
00121         int             socket;
00122         long            need_bytes;     // bytes needed to complete message
00123         smart_ptr<MessageBuffer> msgbuf;// long-lived message buffer
00124         eConnectionType type;           // what sort of connection?
00125         message_queue_t message_queue;  // pending messages to write
00126         address_t       address;        // where messages go
00127         address_t       udpFrom;        // for received UDP messages
00128         long            send_byte;      // current send byte
00129         char            buffer[s_chunkSize];    // buffer for reading
00130         int             buff_idx;       // where in buffer are we?
00131         ssize_t         buff_len;       // how much did we read?
00132         char            header[s_maxHeaderLine];
00133         int             head_idx;       // index into header line
00134 };
00135 
00136 // connection ID --> connection record
00137 typedef std::map<conn_id_t, smart_ptr<conn_rec_t> > conn_map_t;
00138 
00139 static conn_map_t s_connection_map;
00140 
00141 
00142 
00143 // stats!
00144 static dword_t s_messagesSent                   = 0;
00145 static dword_t s_messagesReceived               = 0;
00146 static qword_t s_bytesWritten                   = 0;
00147 static qword_t s_bytesRead                      = 0;
00148 
00149 
00150 // connection type names
00151 
00152 
00153 ////////////////////////////////////////////////////////////////////////////////
00154 //
00155 //      static helper methods
00156 //
00157 ////////////////////////////////////////////////////////////////////////////////
00158 
00159 static const char *
00160 getTypeName
00161 (
00162 IN eConnectionType type
00163 )
00164 {
00165         switch (type) {
00166 
00167         case eType_TCP:
00168                 return "TCP client";
00169 
00170         case eType_UDPLocal:
00171                 return "Local UDP port";
00172 
00173         case eType_UDPRemote:
00174                 return "Remote UDP port";
00175 
00176         case eType_TCPListener:
00177                 return "Local TCP Listener";
00178 
00179         default:
00180                 break;
00181         }
00182         return "Unknown connection type!";
00183 }
00184 
00185 
00186 
00187 static conn_id_t
00188 getNewConnectionId
00189 (
00190 void
00191 )
00192 {
00193         // originally I was just using a dumb counter, to avoid having
00194         // to look for collisions etc.  But counters can wrap!  So I'm
00195         // using random numbers
00196         static const dword_t s_dwMax = 0x10000 - 1;
00197         for (;;) {
00198                 conn_id_t conn_id = 1 + (rand() % s_dwMax);
00199                 if (s_connection_map.end() == s_connection_map.find(conn_id))
00200                         return conn_id;
00201         }
00202 }
00203 
00204 
00205 
00206 static void
00207 dumpErrorInfo
00208 (
00209 IN const char * msg
00210 )
00211 {
00212         const int s_bufsize = 256;
00213         char buffer[s_bufsize];
00214 
00215         wsGetErrorMessage(buffer, s_bufsize);
00216 
00217         DPRINTF("%s", msg);
00218         DPRINTF("%s", buffer);
00219 }
00220 
00221 
00222 
00223 static void
00224 verify
00225 (
00226 IN bool isOK,
00227 IN const char * msg
00228 )
00229 {
00230         if (isOK)
00231                 return;         // no problem!
00232 
00233         // if we're here, there is a big problem!
00234         dumpErrorInfo(msg);
00235         ASSERT(false, "halting");
00236 }
00237 
00238 
00239 
00240 static void
00241 verifyThrow
00242 (
00243 IN bool isOK,
00244 IN const char * msg
00245 )
00246 {
00247         if (isOK)
00248                 return;         // no problem!
00249 
00250         const int s_bufsize = 256;
00251         char buffer[s_bufsize];
00252         wsGetErrorMessage(buffer, s_bufsize);
00253 
00254         ASSERT(msg, "null");
00255         DPRINTF("Error!  '%s' on '%s'", buffer, msg);
00256 
00257         WAVE_EX(wex);
00258         wex << msg << "\n";
00259         wex << "Error: " << buffer;
00260 }
00261 
00262 
00263 
00264 static conn_rec_t *
00265 getConnectionRecord
00266 (
00267 IN conn_id_t conn_id
00268 )
00269 {
00270         ASSERT(conn_id, "null");
00271 
00272         conn_map_t::iterator i = s_connection_map.find(conn_id);
00273         if (s_connection_map.end() == i) {
00274                 DPRINTF("Connection ID 0x%lx not found!", (long) conn_id);
00275                 return NULL;
00276         }
00277 
00278         return i->second;
00279 }
00280 
00281 
00282 
00283 static bool
00284 readBuffer
00285 (
00286 IN conn_rec_t * rec
00287 )
00288 {
00289         // timer itself impacts timing!
00290         // perf::Timer timer("netlib::readBuffer");
00291         ASSERT(rec, "null");
00292 
00293         // reset
00294         rec->buffer[0] = 0;
00295         rec->buff_len = 0;
00296         rec->buff_idx = -1;
00297 
00298         // make nonblocking call to see if we have any data to read
00299         // NOTE: the type of read call depends on the type of socket!
00300         ssize_t bytes = -2;
00301         if (eType_TCP == rec->type) {
00302                 bytes = wsReceive(rec->socket, rec->buffer, s_chunkSize - 1);
00303         } else if (eType_UDPLocal == rec->type) {
00304                 // DPRINTF("Reading UDP packet...");
00305                 bytes = wsReceiveFrom(rec->socket, rec->buffer, s_chunkSize - 1,
00306                     rec->udpFrom);
00307         } else {
00308                 DPRINTF("Bad local connection type?  Disconnecting...");
00309                 bytes = 0;
00310         }
00311 
00312         // error?
00313         if (bytes < 0) {
00314                 if (eWS_Again == wsGetError()) {
00315                         // no problem, try again later
00316                         return false;
00317                 }
00318                 DPRINTF("Error receiving!  Disconnecting client...");
00319                 bytes = 0;
00320         }
00321 
00322         // client gave up?
00323         if (0 == bytes) {
00324                 DPRINTF("Client has disconnected");
00325                 closeConnection(rec->conn_id);
00326                 return false;
00327         }
00328 
00329         s_bytesRead += bytes;
00330 
00331         // received data!
00332         if (bytes >= s_chunkSize) {
00333                 DPRINTF("ERROR: client sent too many bytes!");
00334                 DPRINTF("Our buffer size is %d bytes",
00335                     s_chunkSize - 1);
00336                 DPRINTF("Client sent %ld bytes", (long) bytes);
00337                 DPRINTF("Truncating data!");
00338                 ASSERT(false, "HALT");  // this is bad
00339                 bytes = s_chunkSize - 1;
00340         }
00341         rec->buffer[bytes] = 0; // force null-termination
00342         rec->buff_len = bytes;
00343         rec->buff_idx = 0;
00344 
00345         return true;
00346 }
00347 
00348 
00349 
00350 static void
00351 parseHeaderLine
00352 (
00353 IN conn_rec_t * rec
00354 )
00355 {
00356         // at the moment, this routine is fast enough that the timer itself
00357         //   adds significant time!
00358         // perf::Timer timer("netlib::parseHeaderLine");
00359         ASSERT(rec, "null");
00360         ASSERT(rec->head_idx >= 0 && rec->head_idx < s_maxHeaderLine,
00361             "Bad header byte index: %d", rec->head_idx);
00362 
00363         // TODO: avoid allocations here!  (use of std::string)
00364 //      rec->dump("Parsing header");
00365 
00366         // end of line!  Interesting?
00367         rec->header[rec->head_idx] = 0; // null-terminate
00368         std::string key;
00369         const char * p = getNextTokenFromString(rec->header, key, eParse_None);
00370         //DPRINTF("key = '%s'", key.c_str());
00371 
00372         std::string val;
00373         getNextTokenFromString(p, val, eParse_None);
00374         //DPRINTF("val = '%s'", val.c_str());
00375 
00376         // now what?
00377         // At the moment, the header consists of a single line: the byte
00378         //  count (size), which is identified by a leading "s" character.
00379         if ("s" == key) {
00380                 rec->need_bytes = atol(val.c_str());
00381                 //DPRINTF("Message bytes: %ld", rec->need_bytes);
00382                 if (rec->need_bytes <= 0) {
00383                         DPRINTF("Bad byte count? %ld", rec->need_bytes);
00384                         rec->need_bytes = 0;
00385                 }
00386         } else {
00387                 //DPRINTF("Unknown message header key!");
00388         }
00389 }
00390 
00391 
00392 
00393 static bool
00394 handleData
00395 (
00396 IN conn_rec_t * rec,
00397 IO envelope_t& envelope,
00398 IO smart_ptr<MessageBuffer>& msgbuf
00399 )
00400 {
00401         // timer itself impacts timing!
00402         // perf::Timer timer("netlib::handleData");
00403         ASSERT(rec, "null");
00404         ASSERT(envelope.is_empty(), "not empty");
00405         ASSERT(!msgbuf, "should be null");
00406 
00407 //      rec->dump("Reading data");
00408 
00409         // DPRINTF("Got data to read!");
00410 
00411         // keep reading!
00412         const char * p = NULL;
00413         while (true) {
00414 
00415                 if (rec->buff_idx >= rec->buff_len) {
00416                         rec->buff_idx = -1;
00417                 }
00418 
00419                 // DPRINTF("Starting loop, idx = %d", rec->buff_idx);
00420                 if (rec->buff_idx < 0) {
00421                         // need to read buffer!
00422                         if (!readBuffer(rec))
00423                                 return false;   // couldn't read
00424                         if (rec->buff_idx < 0) {
00425                                 return false;   // failed to read anyway
00426                         }
00427                         //DPRINTF("  Read %d bytes", rec->buff_len);
00428                 }
00429                 p = rec->buffer + rec->buff_idx;
00430                 //DPRINTF("  After message read, idx=%d", rec->buff_idx);
00431                 //DPRINTF("  %d bytes remaining in buffer",
00432                 //    rec->buff_len - rec->buff_idx);
00433 
00434                 // parse message headers if necessary
00435                 const char * maxP = rec->buffer + rec->buff_len;
00436                 for (; p < maxP && rec->need_bytes < 0; ++p) {
00437                         
00438                         // are we getting a lot of garbage from client?
00439                         if (rec->head_idx >= s_maxHeaderLine - 1) {
00440                                 // too big!  Reset
00441                                 DPRINTF("Garbage from remote host!  Resetting");
00442                                 rec->head_idx = 0;
00443                         }
00444 
00445                         // push new byte to end of our header buffer
00446                         rec->header[rec->head_idx] = *p;
00447                         rec->head_idx++;
00448 
00449                         if ('\n' == *p) {
00450                                 parseHeaderLine(rec);
00451                                 rec->head_idx = 0;
00452                         } else if (!*p) {
00453                                 // null in header line?  Weird!
00454                                 rec->head_idx = 0;
00455                         }
00456                 }
00457 
00458                 // should be message data
00459                 int remain = rec->buff_len - (p - rec->buffer);
00460                 // DPRINTF("%d bytes remain", remain);
00461                 rec->buff_idx = p - rec->buffer;
00462 
00463                 if (!remain) {
00464 //                      DPRINTF("end of buffer");       // very common!
00465                         continue;       // end of buffer
00466                 }
00467 
00468                 // no point in proceeding if we aren't ready
00469                 if (rec->need_bytes < 0) {
00470                         DPRINTF("aren't ready");
00471                         ASSERT(!*p, "should be out of data");
00472 //                      rec->need_bytes = 0;
00473                         return false;   // ran out of data from read
00474                 }
00475 
00476                 // what we expected?
00477                 if (rec->need_bytes < remain) {
00478 //                      DPRINTF("ERROR: received more bytes than expected!");
00479 //                      DPRINTF("  expected: %ld", rec->need_bytes);
00480 //                      DPRINTF("  received: %d", remain);
00481 //                      DPRINTF("  truncating!!!");  - NO!  Not truncating...
00482                         // not a problem: we just take what we need
00483                         remain = rec->need_bytes;
00484                 }
00485 
00486                 // need to create buffer?
00487                 if (!rec->msgbuf) {
00488                         rec->msgbuf = MessageBuffer::create();
00489                         ASSERT(rec->msgbuf, "failed to create message buffer?");
00490                         rec->msgbuf->reserve(rec->need_bytes + 1);
00491                 }
00492 
00493                 // append
00494                 rec->msgbuf->appendData(p, remain);
00495 
00496                 // decrement
00497                 //DPRINTF("  Copied %d bytes...", remain);
00498                 rec->need_bytes -= remain;
00499                 rec->buff_idx = p + remain - rec->buffer;
00500                 if (rec->need_bytes < 1) {
00501                         //DPRINTF("Message is now complete!");
00502                         rec->msgbuf->close();
00503                         //DPRINTF("Message size: %ld bytes",
00504                         //    rec->msgbuf->getBytes());
00505 
00506                         // hand buffer over to message
00507                         msgbuf = rec->msgbuf;
00508 
00509                         // construct envelope information
00510                         envelope.fromConnId = rec->conn_id;
00511                         envelope.type = rec->type;
00512 
00513                         // Need to swap out for UDP!
00514                         // We read from local UDP port (of course), but client
00515                         // needs to know which remote UDP client sent this.
00516                         if (eType_UDPLocal == envelope.type) {
00517                                 envelope.type = eType_UDPRemote;
00518                                 envelope.fromConnId = 0;
00519                                 envelope.address = rec->udpFrom;
00520                         }
00521 
00522                         // give up ownership and clean up
00523                         rec->msgbuf = 0;
00524                         rec->need_bytes = -1;
00525                         s_messagesReceived++;
00526 
00527                         return true;
00528                 }
00529         }
00530 
00531         // nope
00532         return false;
00533 }
00534 
00535 
00536 
00537 static conn_id_t
00538 addConnectionRecord
00539 (
00540 IN eConnectionType type,
00541 IN int socket,
00542 IN const address_t& address
00543 )
00544 {
00545         ASSERT(eType_Invalid != type, "Bad type");
00546         ASSERT(socket > -2, "bad socket");
00547 
00548 //      DPRINTF("Creating connection record for '%s':%d ...", host, port);
00549 
00550         if (-1 == socket) {
00551                 ASSERT(eType_UDPRemote == type,
00552                     "Bad connection type (%d) for socket %d", type, socket);
00553         }
00554 
00555         // construct connection record and put in threadsafe map...
00556         smart_ptr<conn_rec_t> rec = new conn_rec_t;
00557         ASSERT(rec, "out of memory?");
00558         rec->socket = socket;
00559         rec->address = address;
00560         rec->type = type;
00561         rec->conn_id = getNewConnectionId();
00562 //      DPRINTF("  Assigning connection id = 0x%04lx", rec->conn_id);
00563 //      rec->dump("Just created");
00564 
00565         // add to map
00566         s_connection_map[rec->conn_id] = rec;
00567         ASSERT(2 == rec.get_ref_count(), "should have 2 refs!");
00568 
00569         //DPRINTF("Currently have %d connections", s_connection_map.size());
00570 
00571         return rec->conn_id;
00572 }
00573 
00574 
00575 
00576 static conn_id_t
00577 handleConnection
00578 (
00579 IN conn_rec_t * rec
00580 )
00581 {
00582         perf::Timer timer("netlib::handleConnection");
00583         ASSERT(rec, "null");
00584         ASSERT(eType_TCPListener == rec->type,
00585             "Requesting to listen on a non-listening socket?");
00586         ASSERT(wsIsValidSocket(rec->socket), "bad socket? %d", rec->socket);
00587 
00588         address_t address;
00589         int c = wsAccept(rec->socket, address);
00590         if (!wsIsValidSocket(c)) {
00591                 // nobody wanted to connect!
00592                 DPRINTF("Not a valid connection?");
00593                 return 0;
00594         }
00595 //      address.dump("New connection");
00596 
00597         return addConnectionRecord(eType_TCP, c, address);
00598 }
00599 
00600 
00601 
00602 static void
00603 writeMessage
00604 (
00605 IN conn_rec_t * rec
00606 )
00607 {
00608         perf::Timer timer("netlib::writeMessage");
00609         ASSERT(rec, "null");
00610 
00611         ASSERT(rec->message_queue.size(), "empty queue?");
00612 
00613         //DPRINTF("Have message to write!");
00614 
00615         // get the first message in the queue
00616         const request_t& req = rec->message_queue.front();
00617         ASSERT(!req.is_empty(), "empty message in queue?");
00618         ASSERT(req.connId, "null");
00619 
00620         // for UDP, need additional stuff...
00621         conn_rec_t * recTo = NULL;
00622         if (eType_UDPLocal == rec->type) {
00623                 // we must be sending to a remote UDP port
00624                 recTo = getConnectionRecord(req.connId);
00625                 ASSERT(recTo, "null entry for UDP receiver");
00626                 ASSERT(eType_UDPRemote == recTo->type,
00627                     "receiver is wrong type of connection");
00628         } else if (eType_UDPBroadcast == rec->type) {
00629                 // broadcast: we know where to send
00630                 recTo = rec;
00631         }
00632 
00633         // try to write header?
00634         if (rec->send_byte < 0) {
00635                 // yes, header has not yet been sent
00636 
00637                 // construct header and send
00638                 const int s_headerBytes = 32;
00639                 char header[s_headerBytes];
00640                 sprintf(header, "\ns %ld\n", req.msgbuf->getBytes());
00641                 //DPRINTF("Header:\n-----%s-----", header);
00642                 int hbytes = strlen(header);
00643                 //DPRINTF("About to send header...");
00644                 //DPRINTF(" hbytes = %d", hbytes);
00645                 //DPRINTF(" s = %d", rec->socket);
00646                 long bytes;
00647                 if (!recTo) {
00648                         bytes = wsSend(rec->socket, header, hbytes);
00649                 } else {
00650                         // recTo->address.dump("Sending here");
00651                         bytes = wsSendTo(rec->socket, header, hbytes,
00652                             recTo->address);
00653                 }
00654                 if (bytes < 0) {
00655                         // error!  bad or not?
00656                         if (eWS_Again == wsGetError()) {
00657                                 //DPRINTF("header send failed, will try again");
00658                                 return;         // quietly fail, try again later
00659                         }
00660                         DPRINTF("Error writing to client!  will close connection");
00661                         closeConnection(rec->conn_id);
00662                         return;
00663                 }
00664                 verifyThrow(bytes == hbytes, "failed to send complete header!");
00665                 s_bytesWritten += bytes;
00666 
00667                 // header successfully sent!
00668                 //DPRINTF("Successfully sent message header");
00669                 rec->send_byte = 0;     // no data sent yet
00670         }
00671 
00672         // keep sending...
00673         while (1) {
00674                 long to_send = req.msgbuf->getBytes() - rec->send_byte;
00675                 if (to_send < 1) {
00676                         //DPRINTF("Completed message send!");
00677 
00678                         // all done
00679                         s_messagesSent++;
00680                         break;
00681                 }
00682 
00683                 // don't send more than the reader can handle
00684                 if (to_send >= s_chunkSize)
00685                         to_send = s_chunkSize - 1;
00686 
00687                 const char * data = req.msgbuf->getData() + rec->send_byte;
00688                 long sent;
00689                 if (!recTo) {
00690                         sent = wsSend(rec->socket, data, to_send);
00691                 } else {
00692                         sent = wsSendTo(rec->socket, data, to_send,
00693                             recTo->address);
00694                 }
00695                 if (sent < 0) {
00696                         // error!  bad or not?
00697                         if (eWS_Again == wsGetError()) {
00698                                 DPRINTF("socket send() failed, will try again");
00699                                 return;
00700                         }
00701                         DPRINTF("Error writing to client!  will close connection");
00702                         closeConnection(rec->conn_id);
00703                         return;
00704                 }
00705                 s_bytesWritten += sent;
00706                 // DPRINTF("  sent %ld bytes", sent);
00707 
00708                 // update stats on bytes sent
00709                 rec->send_byte += sent;
00710         }
00711 
00712         // update socket data (can do this in place)
00713         rec->message_queue.pop_front();         // done with message!
00714         rec->send_byte = -1;                    // reset sent count
00715 }
00716 
00717 
00718 
00719 static void
00720 handleWrites
00721 (
00722 IN ws_set_t writeable
00723 )
00724 {
00725         // timer itself impacts timing!
00726         // perf::Timer timer("netlib::handleWrites");
00727         ASSERT(writeable, "null");
00728 
00729         for (conn_map_t::iterator i = s_connection_map.begin();
00730              i != s_connection_map.end(); ++i) {
00731                 conn_rec_t * rec = i->second;
00732                 ASSERT(rec, "null connection record");
00733                 int s = rec->socket;
00734                 if (!wsIsValidSocket(s))
00735                         continue;       // not a real socket!
00736 
00737                 if (!wsIsSocketInSet(writeable, s))
00738                         continue;       // socket isn't writeable
00739 
00740                 // DPRINTF("Have a message to write!");
00741                 if (!rec->message_queue.size())
00742                         continue;       // nothing to write anyway!
00743 
00744                 writeMessage(rec);
00745         }
00746 }
00747 
00748 
00749 
00750 static bool
00751 handleRead
00752 (
00753 IN ws_set_t readers,
00754 IO envelope_t& envelope,
00755 IO smart_ptr<MessageBuffer>& msgbuf
00756 )
00757 {
00758         // timer itself affects timing!
00759         // perf::Timer timer("netlib::handleRead");
00760         ASSERT(readers, "null");
00761         ASSERT(envelope.is_empty(), "not empty");
00762         ASSERT(!msgbuf, "should be null");
00763 
00764         for (conn_map_t::iterator i = s_connection_map.begin();
00765              i != s_connection_map.end(); ++i) {
00766                 conn_rec_t * rec = i->second;
00767                 ASSERT(rec, "null connection record");
00768                 int s = rec->socket;
00769                 if (!wsIsValidSocket(s))
00770                         continue;               // not a real socket
00771 
00772                 if (rec->buff_idx >= 0) {
00773                         // rec->dump("More data in buffer");
00774                         if (handleData(rec, envelope, msgbuf))
00775                                 return true;
00776                 }
00777 
00778                 if (!wsIsSocketInSet(readers, s))
00779                         continue;       // socket not impacted
00780 
00781                 if (eType_TCPListener == rec->type) {
00782                         // DPRINTF("Got a request to connect!");
00783                         handleConnection(rec);
00784                         return false;
00785                 } else {
00786                         // DPRINTF("Received data!");
00787                         // rec->dump("New incoming data");
00788                         if (handleData(rec, envelope, msgbuf))
00789                                 return true;
00790                 }
00791         }
00792 
00793         // can get here if no sockets were available for reads!
00794         return false;
00795 }
00796 
00797 
00798 
00799 ////////////////////////////////////////////////////////////////////////////////
00800 //
00801 //      public API
00802 //
00803 ////////////////////////////////////////////////////////////////////////////////
00804 
00805 std::string
00806 getServerFromIP
00807 (
00808 IN const ip_addr_t& ip
00809 )
00810 {
00811         // TODO: DNS reverse lookup!
00812         // for now, just string-ify the IP address
00813         ASSERT(1 == ip.flags, "Only works with IPv4 addresses for now...");
00814 
00815         char buffer[64];
00816         sprintf(buffer, "%d.%d.%d.%d",
00817             ip.addr[0], ip.addr[1], ip.addr[2], ip.addr[3]);
00818 
00819         return buffer;
00820 }
00821 
00822 
00823 
00824 conn_id_t
00825 createTcpListener
00826 (
00827 IN const address_t& address,
00828 IN int maxBacklog
00829 )
00830 {
00831         perf::Timer timer("netlib::createTcpListener");
00832         ASSERT(address.isValid(), "Invalid address");
00833         ASSERT(maxBacklog > 0, "bad max backlog: %d", maxBacklog);
00834 
00835         // set up listening socket
00836         int s = wsCreateTcpSocket();
00837         // DPRINTF("TCP listening socket: %d", s);
00838         verify(wsIsValidSocket(s), "Failed to create tcp listening socket");
00839 
00840         // bind (name) the socket
00841         verify(!wsBindToPort(s, address.port),
00842             "Failed to bind tcp listening socket");
00843 
00844         // set up for listening
00845         verify(!wsListen(s, maxBacklog),
00846             "Failed to set up tcp socket for listening");
00847 
00848         // add to our map
00849         return addConnectionRecord(eType_TCPListener, s, address);
00850 }
00851 
00852 
00853 
00854 conn_id_t
00855 createTcpConnection
00856 (
00857 IN const address_t& address
00858 )
00859 {
00860         perf::Timer timer("netlib::createTcpConnection");
00861         ASSERT2(address.isValid(),
00862             "invalid address--cannot create TCP connection");
00863 
00864         // create socket
00865         int c = wsCreateTcpSocket();
00866 
00867         // DPRINTF("TCP connection socket: %d", c);
00868         verifyThrow(wsIsValidSocket(c),
00869             "Failed to create socket for tcp connection");
00870 
00871         // set up address data
00872         verifyThrow(-1 != wsConnect(c, address),
00873             "Failed to connect to server");
00874 
00875 //      DPRINTF("Connected!");
00876         return addConnectionRecord(eType_TCP, c, address);
00877 }
00878 
00879 
00880 
00881 conn_id_t
00882 createUdpLocal
00883 (
00884 IN const address_t& address
00885 )
00886 {
00887         perf::Timer timer("netlib::createUdpLocal");
00888         ASSERT(address.isValid(), "invalid address");
00889 
00890         // create socket
00891         int c = wsCreateUdpSocket(false);       // not broadcast
00892         DPRINTF("Local UDP socket: %d", c);
00893         verifyThrow(wsIsValidSocket(c), "Failed to create socket for local udp");
00894 
00895         // bind to local port
00896         verifyThrow(-1 != wsBindToPort(c, address.port),
00897             "Failed to bind to local UDP socket");
00898 
00899         // all done!
00900         return addConnectionRecord(eType_UDPLocal, c, address);
00901 }
00902 
00903 
00904 
00905 conn_id_t
00906 createUdpRemote
00907 (
00908 IN conn_id_t localUdp,
00909 IN const address_t& address
00910 )
00911 {
00912         perf::Timer timer("netlib::createUdpRemote");
00913         ASSERT(localUdp, "null");
00914         ASSERT2(address.isValid(),
00915             "Address is invalid--cannot create remote UDP connection");
00916 
00917         // create connection record
00918         conn_id_t conn_id =
00919             addConnectionRecord(eType_UDPRemote, -1, address);
00920         ASSERT(conn_id, "null");
00921 
00922         // retrieve it because we're going to tweak it...
00923         conn_rec_t * rec = getConnectionRecord(conn_id);
00924         ASSERT(rec, "null");
00925         rec->local = localUdp;
00926         DPRINTF("Added remote udp connection");
00927 
00928         // all done!
00929         return conn_id;
00930 }
00931 
00932 
00933 
00934 conn_id_t
00935 createUdpBroadcast
00936 (
00937 IN const address_t& broadcastAddress
00938 )
00939 {
00940         ASSERT2(broadcastAddress.isValid(), "Invalid broadcast address");
00941 
00942         // create socket
00943         int s = wsCreateUdpSocket(true);        // yes, broadcast
00944         verifyThrow(wsIsValidSocket(s),
00945             "Failed to create socket for udp broadcast");
00946 
00947         // create connection record
00948         return addConnectionRecord(eType_UDPBroadcast, s, broadcastAddress);
00949 }
00950 
00951 
00952 
00953 bool
00954 enqueueMessage
00955 (
00956 IN conn_id_t conn_id,
00957 IN smart_ptr<MessageBuffer>& message
00958 )
00959 {
00960         // at the moment, this routine is fast enough that the timer itself
00961         //      adds significant time!
00962         // perf::Timer timer("netlib::enqueueMessage");
00963         ASSERT(conn_id, "null");
00964         ASSERT(message, "null");
00965 
00966         // look for this connection
00967         conn_rec_t * rec = getConnectionRecord(conn_id);
00968         if (!rec) {
00969                 DPRINTF("Connection id not recognized? 0x%04lx", (long) conn_id);
00970                 return false;
00971         }
00972         ASSERT(rec, "null record in map for connection id 0x%04lx", (long) conn_id);
00973         // DPRINTF("Enqueuing message for connection 0x%04lx", conn_id);
00974         // DPRINTF("  Local protocol is %d", rec->type);
00975 
00976         // improper connection type for sending?
00977         if (eType_UDPLocal == rec->type) {
00978                 DPRINTF("Cannot send messages to local UDP!  Send to remote");
00979                 return false;
00980         }
00981 
00982         // udp?  In that case, need to swap local + remote
00983         if (eType_UDPRemote == rec->type) {
00984                 ASSERT(rec->local, "null local UDP?");
00985                 conn_rec_t * localRec = getConnectionRecord(rec->local);
00986                 ASSERT(localRec, "local udp sender disappeared");
00987                 rec = localRec;         // swap out and queue here
00988         }
00989         ASSERT(wsIsValidSocket(rec->socket),
00990             "null socket in connection record for 0x%04lx", (long) conn_id);
00991 
00992         // add to message queue
00993         request_t req;
00994         req.connId = conn_id;
00995         req.msgbuf = message;
00996         rec->message_queue.push_back(req);
00997 
00998         // all done
00999         return true;
01000 }
01001 
01002 
01003 
01004 
01005 bool
01006 getNextMessage
01007 (
01008 IN long wait_microseconds,
01009 OUT envelope_t& envelope,
01010 OUT smart_ptr<MessageBuffer>& msgbuf
01011 )
01012 {
01013         perf::Timer timer("netlib::getNextMessage");
01014         ASSERT(wait_microseconds >= 0, "Bad wait: %ld", wait_microseconds);
01015         envelope.clear();
01016         ASSERT(!msgbuf, "unecessary free?");
01017 
01018         // arbitrary assertion here
01019         ASSERT(wait_microseconds < 1000 * 1000, "Wait is too long! %ld usec",
01020             wait_microseconds);
01021 
01022         // mask of types that can read
01023         dword_t readMask = eType_TCP | eType_UDPLocal | eType_TCPListener;
01024 
01025         // see if any sockets have data for us
01026         static ws_set_t readers = 0;
01027         if (!readers) {
01028                 readers = wsCreateSet();
01029         }
01030         wsClearSet(readers);
01031 
01032         // also see if any sockets are ready to be sent out on
01033         static ws_set_t writeable = 0;
01034         if (!writeable) {
01035                 writeable = wsCreateSet();
01036         }
01037         wsClearSet(writeable);
01038 
01039         // add all sockets we care about
01040         int max = 0;
01041         for (conn_map_t::iterator i = s_connection_map.begin();
01042              i != s_connection_map.end(); ++i) {
01043                 const conn_rec_t * rec = i->second;
01044                 ASSERT(eType_Invalid != rec->type, "invalid connection type?");
01045 
01046                 // only read from certain sockets
01047                 if (readMask & rec->type) {
01048                         wsAddSocketToSet(readers, rec->socket);
01049                 }
01050 
01051                 // interested in writing if message is pending
01052                 if (rec->message_queue.size() > 0) {
01053                         wsAddSocketToSet(writeable, rec->socket);
01054                 }
01055 
01056                 // update max socket ID?
01057                 if (rec->socket > max)
01058                         max = rec->socket;
01059         }
01060 
01061         // go get 'em
01062         int count = wsSelect(max + 1, readers, writeable, wait_microseconds);
01063         if (count < 0) {
01064                 dumpErrorInfo("select() call failed");
01065                 return false;
01066         }
01067 
01068         // first see if anything can be written
01069         handleWrites(writeable);
01070 
01071         // okay, see if there is anything to read!
01072         return handleRead(readers, envelope, msgbuf);
01073 }
01074 
01075 
01076 
01077 bool
01078 isValidConnection
01079 (
01080 IN conn_id_t conn_id
01081 )
01082 {
01083         return (NULL != getConnectionRecord(conn_id));
01084 }
01085 
01086 
01087 
01088 bool
01089 getConnectionInfo
01090 (
01091 IN conn_id_t conn_id,
01092 OUT connection_info_t& info
01093 )
01094 {
01095         ASSERT(conn_id, "null");
01096         info.clear();
01097 
01098         conn_rec_t * rec = getConnectionRecord(conn_id);
01099         if (!rec) {
01100                 return false;
01101         }
01102 
01103         info.type = rec->type;
01104         info.address = rec->address;
01105 
01106         // all done
01107         return true;
01108 }
01109 
01110 
01111 
01112 void
01113 closeConnection
01114 (
01115 IN conn_id_t conn_id
01116 )
01117 {
01118         ASSERT(conn_id, "null");
01119 
01120         DPRINTF("Closing connection 0x%lx...", (long) conn_id);
01121 
01122         conn_map_t::iterator i = s_connection_map.find(conn_id);
01123         if (s_connection_map.end() == i) {
01124                 DPRINTF("Error in closeConnection() -- connection id 0x%lx not found",
01125                     (long) conn_id);
01126                 return;
01127         }
01128 
01129         // remove the connection
01130         s_connection_map.erase(i);
01131 }
01132 
01133 
01134 
01135 void
01136 dumpMessage
01137 (
01138 IO std::ostream& stream,
01139 IN const char * title,
01140 IN const envelope_t& envelope,
01141 IN const MessageBuffer * buffer
01142 )
01143 {
01144         ASSERT(stream.good(), "bad?");
01145         ASSERT(title, "null");
01146 
01147         DPRINTF("Message dump: %s", title);
01148         if (envelope.is_empty()) {
01149                 DPRINTF("  Envelope is empty!");
01150         } else {
01151                 DPRINTF("  From: 0x%04lx", (long) envelope.fromConnId);
01152                 DPRINTF("  Connection Type: %d (%s)", envelope.type,
01153                     getTypeName(envelope.type));
01154                 envelope.address.dump(title);
01155         }
01156 
01157         if (!buffer) {
01158                 DPRINTF("  Null message buffer!");
01159         } else {
01160                 DPRINTF("  Message: '%s'", buffer->getData());
01161         }
01162 }
01163 
01164 
01165 
01166 void
01167 dumpStats
01168 (
01169 void
01170 )
01171 {
01172         DPRINTF("Networking stats:");
01173         DPRINTF("  Total messages received: %6u", s_messagesReceived);
01174         DPRINTF("  Total messages sent:     %6u", s_messagesSent);
01175 //#ifdef WIN32
01176 //      // sigh... Windows has a non-standard format for long long
01177 //      DPRINTF("  Total bytes read:    %10I64d  (%4I64d MB)", s_bytesRead,
01178 //          (s_bytesRead + 512 * 1024) / (1024 * 1024));
01179 //      DPRINTF("  Total bytes written: %10I64d  (%4I64d MB)", s_bytesWritten,
01180 //          (s_bytesWritten + 512 * 1024) / (1024 * 1024));
01181 //#else // WIN32
01182         std::cerr << "  Total bytes received: " << s_bytesRead;
01183         std::cerr << "  (" << (s_bytesRead + 512 * 1024) / (1024 * 1024) << " MB)\n";
01184         std::cerr << "  Total bytes sent: " << s_bytesWritten;
01185         std::cerr << "  (" << (s_bytesWritten + 512 * 1024) / (1024 * 1024) << " MB)\n";
01186 //#endif        // WIN32
01187 
01188         if (s_messagesReceived > 0) {
01189                 long avg = (long) (s_bytesRead / s_messagesReceived);
01190                 DPRINTF("  Average size of message read: %5ld bytes", avg);
01191         }
01192         if (s_messagesSent > 0) {
01193                 long avg = (long) (s_bytesWritten / s_messagesSent);
01194                 DPRINTF("  Average size of message sent: %5ld bytes", avg);
01195         }
01196 }
01197 
01198 
01199 void
01200 connection_info_t::dump
01201 (
01202 IN const char * title
01203 )
01204 const
01205 throw()
01206 {
01207         DPRINTF("Connection info: %s", title);
01208         DPRINTF("  type: %d", type);
01209         address.dump(title);
01210 }
01211 
01212 
01213 };      // netlib namespace
01214